diff --git a/client.py b/client.py index 83cd2f6..d0fa087 100644 --- a/client.py +++ b/client.py @@ -4,23 +4,32 @@ __status__ = "Development" if (__name__ == "__main__"): import config - import socket as network - import message + import socket + import select + import sys - def init() -> None: - address = (config.host, config.port) + address = (config.host, config.port) - print("Starting connection to", address) + print("Starting connection to", address) - with network.socket(network.AF_INET, network.SOCK_STREAM) as socket: - socket.setblocking(False) - socket.connect_ex(address) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + server_socket.setblocking(False) + server_socket.connect_ex(address) - while True: - message_body = input("Enter a message: ") + inputs = [sys.stdin, server_socket] - print("\n") - socket.send(message.serialize("Kayomn", message_body)) + while True: + readable_io, _, _ = select.select(inputs, [], []) - init() + for io in readable_io: + if (io == server_socket): + print(io.recv(4096)) + + elif (io == sys.stdin): + message = sys.stdin.readline() + + server_socket.send(message.encode("utf-8")) + sys.stdout.write("") + sys.stdout.write(message) + sys.stdout.flush() diff --git a/server.py b/server.py index 1c64d69..f763a81 100644 --- a/server.py +++ b/server.py @@ -3,102 +3,56 @@ __version__ = "0.0.1" __status__ = "Development" if (__name__ == "__main__"): + import threading import config - import traceback - import socket as network - import selectors - import message + import socket - class User: - def __init__(self, user_socket: network.socket): - self.socket = user_socket - self.connection, self.address = user_socket.accept() - self.selector = selectors.DefaultSelector() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket: + # Avoid bind() exception: OSError: [Errno 48] Address already in use + client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + client_socket.bind((config.host, config.port)) + client_socket.listen() - print("Accepted connection from", self.address) - self.connection.setblocking(False) + clients = [] - def close(self): - self.selector.close() - self.socket.close() - - def read(self): - try: - # Should be ready to read - data = self.socket.recv(4096) - except BlockingIOError: - # Resource temporarily unavailable (errno EWOULDBLOCK) - pass - else: - if data: - print(message.deserialize(data)) - else: - raise RuntimeError("Peer closed.") - - def write(self): - pass - - def run(self): - try: - while True: - events = self.selector.select(timeout=None) - - for (key, mask) in events: - try: - if (mask & selectors.EVENT_READ): - self.read() - - elif (mask & selectors.EVENT_WRITE): - self.write() - - except Exception: - print(f"Exception raised on {self.address}: {traceback.format_exc()}"), - self.socket.close() - - if not self.selector.get_map(): - break - - except KeyboardInterrupt: - print("caught keyboard interrupt, exiting") - - finally: - self.selector.close() - - def init(): - address = (config.host, config.port) - selector = selectors.DefaultSelector() - - with network.socket(network.AF_INET, network.SOCK_STREAM) as socket: - # Avoid bind() exception: OSError: [Errno 48] Address already in use - socket.setsockopt(network.SOL_SOCKET, network.SO_REUSEADDR, 1) - socket.bind(address) - socket.listen() - - print("Listening on", address) - - socket.setblocking(False) - selector.register(socket, selectors.EVENT_READ, data=None) - - users = [] + def spawn_client(client_connection, client_address): + client_connection.send("Welcome to this chatroom!".encode("utf-8")) try: - while True: - events = selector.select(timeout=None) + data = client_connection.recv(4096) - for (key, mask) in events: - if key.data is None: - user = User(key.fileobj) + print("test", data if data else "OOOPS") - users.append(user) - user.run() + while data: + message = "<" + client_address[0] + "> " + data - except KeyboardInterrupt: - print("Caught keyboard interrupt, exiting") + print(message) - finally: - for user in users: - user.close() + for client in clients: + if client != client_connection: + try: + client.send(message) - selector.close() + except: + client.close() + # if the link is broken, we remove the client + if client_connection in clients: + clients.remove(client_connection) - init() + data = client_connection.recv(4096) + + if client_connection in clients: + print("dead") + clients.remove(client_connection) + + except: + return + + while True: + connection, address = client_socket.accept() + + clients.append(connection) + + # prints the address of the user that just connected + print(address[0] + " connected") + threading.Thread(target=spawn_client, args=(connection, address)).start()