diff -uNr a/blatta/README.txt b/blatta/README.txt --- a/blatta/README.txt bc8aaa1f75e7830f0359f5a35958fc19b8bbdd316e7a97cf66bba38e0a616bd673751249ce6e9fa2aad7f05e32af3ae9b351e87cac8a679fb8d381510514d813 +++ b/blatta/README.txt 7ee0913ae7addd7e419ccde7a0f4e7c2348029ad00e1d6994da98430875358b91032ded251ee34c7917c81308be84e9042da9d54780b440aec23065a53c498ff @@ -30,3 +30,8 @@ 3. Use genkey to generate a key. 4. Add the key to the peer using the key command. 5. Add an address for the peer using the address command. + +NOTES: + +To run the unit tests, you'll need to run: +pip install mock diff -uNr a/blatta/blatta b/blatta/blatta --- a/blatta/blatta 50acef42c77e18fb23fa8157201ddba47717bb95ac85fe9d23393aba847a39665373cc6c98fdbb08096f321081980f421c17acd4490de8cb56d38bef637c0565 +++ b/blatta/blatta 87fb4a6177c042b76e8c0e34dc153bc2f574e78b90521124a44cf41e02f57aaf13ba15be758f3a67056f53064c8928f3da47dbc0cffb88019ae5456b6f06a9bf @@ -8,6 +8,7 @@ import sys import tempfile import time +import logging from lib.server import VERSION from lib.server import Server from lib.peer import Peer @@ -90,8 +91,11 @@ (options, args) = op.parse_args(argv[1:]) if options.channel_name is None: options.channel_name = "#pest" + log_format = "%(levelname)s %(asctime)s: %(message)s" if options.debug: - options.verbose = True + logging.basicConfig(level=logging.DEBUG, format=log_format, stream=sys.stdout) + else: + logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stdout) if options.irc_ports is None: options.irc_ports = "6697" if options.udp_port is None: @@ -139,7 +143,7 @@ try: server.start() except KeyboardInterrupt: - server.print_error("Interrupted.") + logging.error("Interrupted.") main(sys.argv) diff -uNr a/blatta/lib/client.py b/blatta/lib/client.py --- a/blatta/lib/client.py a35f64ee21532cc117fb36691ece37e2f523cb01855fb4eb8268ac745519f27fd36addb324d56d4893fb7a84fe70302630e53becf3571e250b3b578ca46abce3 +++ b/blatta/lib/client.py 98e1e99d7ea8fe523728ca8f5661b244f3769a09a44cfa8c40236fe5cb32bb2ee267c7ee2e2028d1e455b5fc3bae9efb53e91a457033133c4bb25162ebcf8e5e @@ -6,6 +6,8 @@ import os import base64 import traceback +import logging +from lib.state import State from lib.message import Message from lib.server import VERSION from funcs import * @@ -22,6 +24,7 @@ def __init__(self, server, socket): self.server = server + self.state = State.get_instance() self.socket = socket self.channels = {} # irc_lower(Channel name) --> Channel self.nickname = None @@ -37,15 +40,11 @@ else: self.__handle_command = self.__registration_handler - def is_addressed_to_me(self, message): - command = self.__parse_udp_message(message) - if command[0] == 'PRIVMSG': - if command[1][0][0] == '#' or command[1][0] == self.nickname: - return True - else: - return False - else: - return True + def message_from_station(self, msg): + targetname = self.server.channel_name if msg.command == BROADCAST else self.nickname + pest_prefix = msg.prefix if msg.prefix else msg.speaker + formatted_message = ":%s PRIVMSG %s :%s" % (pest_prefix, targetname, msg.body) + self.__writebuffer += formatted_message + "\r\n" def get_prefix(self): return "%s" % (self.nickname) @@ -68,30 +67,6 @@ def write_queue_size(self): return len(self.__writebuffer) - def __parse_udp_message(self, message): - data = " ".join(message.split()[1:]) + "\r\n" - lines = self.__linesep_regexp.split(data) - lines = lines[:-1] - commands = [] - for line in lines: - if not line: - # Empty line. Ignore. - continue - x = line.split(" ", 1) - command = x[0].upper() - if len(x) == 1: - arguments = [] - else: - if len(x[1]) > 0 and x[1][0] == ":": - arguments = [x[1][1:]] - else: - y = string.split(x[1], " :", 1) - arguments = string.split(y[0]) - if len(y) == 2: - arguments.append(y[1]) - commands.append([command, arguments]) - return commands[0] - def __parse_read_buffer(self): lines = self.__linesep_regexp.split(self.__readbuffer) self.__readbuffer = lines[-1] @@ -159,7 +134,6 @@ % self.nickname) self.reply("004 %s :%s blatta-%s o o" % (self.nickname, server.name, VERSION)) - self.send_lusers() self.send_motd() self.__handle_command = self.__command_handler @@ -182,32 +156,20 @@ if arguments[0] == "0": for (channelname, channel) in self.channels.items(): self.message_channel(channel, "PART", channelname, True) - self.channel_log(channel, "left", meta=True) server.remove_member_from_channel(self, channelname) self.channels = {} return channelnames = arguments[0].split(",") - if len(arguments) > 1: - keys = arguments[1].split(",") - else: - keys = [] - keys.extend((len(channelnames) - len(keys)) * [None]) - for (i, channelname) in enumerate(channelnames): + for channelname in channelnames: if irc_lower(channelname) in self.channels: continue if not valid_channel_re.match(channelname): self.reply_403(channelname) continue channel = server.get_channel(channelname) - if channel.key is not None and channel.key != keys[i]: - self.reply( - "475 %s %s :Cannot join channel (+k) - bad key" - % (self.nickname, channelname)) - continue channel.add_member(self) self.channels[irc_lower(channelname)] = channel self.message_channel(channel, "JOIN", channelname, True) - self.channel_log(channel, "joined", meta=True) if channel.topic: self.reply("332 %s %s :%s" % (self.nickname, channel.name, channel.topic)) @@ -218,7 +180,7 @@ % (self.nickname, channelname, " ".join(sorted(x - for x in self.server.state.get_peer_handles())))) + for x in self.state.get_peer_handles())))) self.reply("366 %s %s :End of NAMES list" % (self.nickname, channelname)) @@ -238,7 +200,7 @@ self.reply("323 %s :End of LIST" % self.nickname) def lusers_handler(): - self.send_lusers() + pass def mode_handler(): if len(arguments) < 1: @@ -268,8 +230,6 @@ self.message_channel( channel, "MODE", "%s +k %s" % (channel.name, key), True) - self.channel_log( - channel, "set channel key to %s" % key, meta=True) else: self.reply("442 %s :You're not on that channel" % targetname) @@ -279,8 +239,6 @@ self.message_channel( channel, "MODE", "%s -k" % channel.name, True) - self.channel_log( - channel, "removed channel key", meta=True) else: self.reply("442 %s :You're not on that channel" % targetname) @@ -313,9 +271,6 @@ self.reply("432 %s %s :Erroneous Nickname" % (self.nickname, newnick)) else: - for x in self.channels.values(): - self.channel_log( - x, "changed nickname to %s" % newnick, meta=True) oldnickname = self.nickname self.nickname = newnick server.client_changed_nickname(self, oldnickname) @@ -340,16 +295,23 @@ channel = server.get_channel(targetname) self.message_channel( channel, command, "%s :%s" % (channel.name, message)) - self.channel_log(channel, message) + # send the channel message to peers as well + self.server.station.infosec.message( + Message( + { + "speaker": self.nickname, + "command": BROADCAST, + "bounces": 0, + "body": message + })) else: - formatted_message = ":%s %s %s :%s" % (self.prefix, command, targetname, message) - self.server.peer_message(Message({ + self.server.station.infosec.message(Message({ "speaker": self.nickname, "handle": targetname, - "body": formatted_message, + "body": message, "bounces": 0, "command": DIRECT - }, self.server)) + })) if(client): client.message(formatted_message) @@ -372,7 +334,6 @@ self.message_channel( channel, "PART", "%s :%s" % (channelname, partmsg), True) - self.channel_log(channel, "left (%s)" % partmsg, meta=True) del self.channels[irc_lower(channelname)] server.remove_member_from_channel(self, channelname) @@ -405,8 +366,6 @@ self.message_channel( channel, "TOPIC", "%s :%s" % (channelname, newtopic), True) - self.channel_log( - channel, "set topic to %r" % newtopic, meta=True) else: if channel.topic: self.reply("332 %s %s :%s" @@ -464,17 +423,21 @@ def wot_handler(): if len(arguments) < 1: # Display the current WOT - peers = self.server.state.get_peers() + peers = self.state.get_peers() if len(peers) > 0: for peer in peers: - self.pest_reply("%s %s:%s" % (string.join(peer.handles, ","), peer.address, peer.port)) + if peer.address and peer.port: + address = "%s:%s" % (peer.address, peer.port) + else: + address = "
" + self.pest_reply("%s %s" % (string.join(peer.handles, ","), address)) else: self.pest_reply("WOT is empty") elif len(arguments) == 1: # Display all WOT data concerning the peer identified by HANDLE, # including all known keys, starting with the most recently used, for that peer. handle = arguments[0] - peer = self.server.state.get_peer_by_handle(handle) + peer = self.state.get_peer_by_handle(handle) if peer: self.pest_reply("keys:") for key in peer.keys: @@ -488,7 +451,7 @@ def peer_handler(): if len(arguments) == 1: try: - self.server.state.add_peer(arguments[0]) + self.state.add_peer(arguments[0]) self.pest_reply("added new peer %s" % arguments[0]) self.message(":%s JOIN %s" % (arguments[0], self.server.channel_name)) except: @@ -499,11 +462,11 @@ def unpeer_handler(): if len(arguments) == 1: try: - self.server.state.remove_peer(arguments[0]) + self.state.remove_peer(arguments[0]) self.pest_reply("removed peer %s" % arguments[0]) self.message(":%s PART %s" % (arguments[0], self.server.channel_name)) except Exception, e: - self.server.print_debug(e) + logging.debug(e) self.pest_reply("Error attempting to remove peer") else: self.pest_reply("Usage: UNPEER ") @@ -518,7 +481,7 @@ handle = arguments[0] key = arguments[1] try: - self.server.state.add_key(handle, key) + self.state.add_key(handle, key) self.pest_reply("added key: %s" % key) except: self.pest_reply("Error attempting to add key") @@ -528,23 +491,23 @@ self.pest_reply("Usage: UNKEY ") else: try: - self.server.state.remove_key(arguments[0]) + self.state.remove_key(arguments[0]) self.pest_reply("removed key: %s" % arguments[0]) except Exception, e: self.pest_reply("Error attempting to remove key") - self.server.print_debug(e) + logging.debug(e) def at_handler(): if len(arguments) == 0: - at = self.server.state.get_at() + at = self.state.get_at() elif len(arguments) == 1: handle = arguments[0] - at = self.server.state.get_at(handle) + at = self.state.get_at(handle) elif len(arguments) == 2: try: handle, address = arguments address_ip, port = string.split(address, ":") - self.server.state.update_address_table({"handle": handle, + self.state.update_at({"handle": handle, "address": address_ip, "port": port}, False) @@ -552,7 +515,7 @@ except Exception as ex: self.pest_reply("Error attempting to update address table") stack = traceback.format_exc() - print(stack) + logger.debug(stack) return elif len(arguments) > 2: self.pest_reply("Usage: AT [] [
]") @@ -599,12 +562,12 @@ except KeyError: self.reply("421 %s %s :Unknown command" % (self.nickname, command)) stack = traceback.format_exc() - print(stack) + logger.debug(stack) def socket_readable_notification(self): try: data = self.socket.recv(2 ** 10) - self.server.print_debug( + logging.debug( "[%s:%d] -> %r" % (self.host, self.port, data)) quitmsg = "EOT" except socket.error as x: @@ -621,7 +584,7 @@ def socket_writable_notification(self): try: sent = self.socket.send(self.__writebuffer) - self.server.print_debug( + logging.debug( "[%s:%d] <- %r" % ( self.host, self.port, self.__writebuffer[:sent])) self.__writebuffer = self.__writebuffer[sent:] @@ -630,7 +593,7 @@ def disconnect(self, quitmsg): self.message("ERROR :%s" % quitmsg) - self.server.print_info( + logging.info( "Disconnected connection from %s:%s (%s)." % ( self.host, self.port, quitmsg)) self.socket.close() @@ -654,31 +617,8 @@ def message_channel(self, channel, command, message, include_self=False): line = ":%s %s %s" % (self.prefix, command, message) - for client in channel.members: - if client != self or include_self: - client.message(line) - # send the channel message to peers as well - self.server.peer_message( - Message( - { - "speaker": self.nickname, - "command": BROADCAST, - "bounces": 0, - "body": line - }, self.server)) - - def channel_log(self, channel, message, meta=False): - if not self.server.logdir: - return - if meta: - format = "[%s] * %s %s\n" - else: - format = "[%s] <%s> %s\n" - timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") - logname = channel.name.replace("_", "__").replace("/", "_") - fp = open("%s/%s.log" % (self.server.logdir, logname), "a") - fp.write(format % (timestamp, self.nickname, message)) - fp.close() + if include_self: + self.message(line) def message_related(self, msg, include_self=False): clients = set() @@ -691,10 +631,6 @@ for client in clients: client.message(msg) - def send_lusers(self): - self.reply("251 %s :There are %d users and 0 services on 1 server" - % (self.nickname, len(self.server.clients))) - def send_motd(self): server = self.server motdlines = server.get_motd_lines() diff -uNr a/blatta/lib/infosec.py b/blatta/lib/infosec.py --- a/blatta/lib/infosec.py 2bebeb6ee55f0941c567e114ee01352f18f96d02ac4e2f037b9ced7e71c219cc2ed51ac615d28e22244c5e07f136ec7ce1fb565c7217c62de4708aee4d8b05d5 +++ b/blatta/lib/infosec.py 2f8e9df6cf92a779900080f585a1b9873218d949be48299950c2b05a64fd5a3abd8ab927f5c5d4c5d51e7c3def2bab13a6ffe532907e681883730d60bd75b0e5 @@ -16,6 +16,7 @@ import random import os import pprint +import logging pp = pprint.PrettyPrinter(indent=4) PACKET_SIZE = 496 @@ -34,32 +35,66 @@ IGNORED = 4 class Infosec(object): - def __init__(self, server=None): - self.server = server + def __init__(self, state=None): + self.state = state - def get_message_bytes(self, message, peer=None): - try: - timestamp = message.timestamp - except: - timestamp = None - command = message.command - speaker = self._pad(message.speaker, MAX_SPEAKER_SIZE) + def message(self, message): + # if we are not rebroadcasting we need to set the timestamp - # if we are rebroadcasting we need to use the original timestamp + if message.timestamp == None: + message.original = True + message.timestamp = int(time.time()) + else: + message.original = False + + target_peer = (self.state.get_peer_by_handle(message.handle) + if message.command == DIRECT + else None) + + if target_peer and not target_peer.get_key(): + logging.debug("No key for peer associated with %s" % message.handle) + return + + if message.command == DIRECT and target_peer == None: + logging.debug("Aborting message: unknown handle: %s" % message.handle) + return + + message_bytes = self.get_message_bytes(message, target_peer) + if message.command != IGNORE: + message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest()) + logging.debug("generated message_hash: %s" % message_hash) + self.state.add_to_dedup_queue(message_hash) + self.state.log(message.speaker, message_bytes, target_peer) + + if message.command == DIRECT: + signed_packet_bytes = self.pack(target_peer, message, message_bytes) + target_peer.send(signed_packet_bytes) + elif message.command == BROADCAST or message.command == IGNORE: + for peer in self.state.get_keyed_peers(): + + # we don't want to send a broadcast back to the originator - if(timestamp == None): - int_ts = int(time.time()) + if message.peer and (peer.peer_id == message.peer.peer_id): + next + + signed_packet_bytes = self.pack(peer, message, message_bytes) + peer.send(signed_packet_bytes) else: - int_ts = timestamp + pass + + def get_message_bytes(self, message, peer=None): + timestamp = message.timestamp + command = message.command + speaker = self._pad(message.speaker, MAX_SPEAKER_SIZE) # let's generate the self_chain value from the last message or set it to zero if # there this is the first message if message.original: if command == DIRECT: - self_chain = self.server.state.get_last_message_hash(message.speaker, peer.peer_id) + self_chain = self.state.get_last_message_hash(message.speaker, peer.peer_id) elif command == BROADCAST: - self_chain = self.server.state.get_last_message_hash(message.speaker) + self_chain = self.state.get_last_message_hash(message.speaker) elif command == IGNORE: self_chain = "\x00" * 32 net_chain = "\x00" * 32 @@ -69,16 +104,19 @@ # pack message bytes - message_bytes = struct.pack(MESSAGE_PACKET_FORMAT, int_ts, self_chain, net_chain, speaker, message.body) + if message.command != IGNORE: + logging.debug("packing message bytes: %s" % message.body) + else: + logging.debug("packing rubbish message bytes: %s" % binascii.hexlify(message.body)) + + message_bytes = struct.pack(MESSAGE_PACKET_FORMAT, message.timestamp, self_chain, net_chain, speaker, message.body) return message_bytes - def pack(self, peer, message): + def pack(self, peer, message, message_bytes): key_bytes = base64.b64decode(peer.get_key()) signing_key = key_bytes[:32] cipher_key = key_bytes[32:] - message_bytes = self.get_message_bytes(message, peer) - # pack packet bytes nonce = self._generate_nonce(16) @@ -111,15 +149,15 @@ try: black_packet_bytes, signature_bytes = struct.unpack(BLACK_PACKET_FORMAT, black_packet) except: - self.server.print_error("Discarding malformed black packet from %s" % peer.get_key()) - return Message({ "error_code": MALFORMED_PACKET }, self.server) + logging.error("Discarding malformed black packet from %s" % peer.get_key()) + return Message({ "error_code": MALFORMED_PACKET }) # check signature signature_check_bytes = hmac.new(signing_key, black_packet_bytes, hashlib.sha384).digest() if(signature_check_bytes != signature_bytes): - return Message({ "error_code": INVALID_SIGNATURE }, self.server) + return Message({ "error_code": INVALID_SIGNATURE }) # try to decrypt black packet @@ -130,10 +168,27 @@ nonce, bounces, version, command, message_bytes = struct.unpack(RED_PACKET_FORMAT, red_packet_bytes) + # compute message_hash + + message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest()) + # unpack message - int_ts, self_chain, net_chain, speaker, message = struct.unpack(MESSAGE_PACKET_FORMAT, message_bytes) - speaker = speaker.strip() + int_ts, self_chain, net_chain, speaker, body = struct.unpack(MESSAGE_PACKET_FORMAT, message_bytes) + + # remove padding from speaker + + for index, byte in enumerate(speaker): + if byte == '\x00': + speaker = speaker[0:index] + break + + # remove padding from body + + for index, byte in enumerate(body): + if byte == '\x00': + body = body[0:index] + break # nothing to be done for an IGNORE command @@ -143,39 +198,26 @@ # check timestamp if(int_ts not in self._ts_range()): - return Message({ "error_code": STALE_PACKET }, self.server) - - # check for duplicates - - message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest()) - if(self.server.state.is_duplicate_message(message_hash)): - return Message({ "error_code": DUPLICATE_PACKET }, self.server) - else: - self.server.state.add_to_dedup_queue(message_hash) + return Message({ "error_code": STALE_PACKET }) # check self_chain if command == DIRECT: - self_chain_check = self.server.state.get_last_message_hash(speaker, peer.peer_id) + self_chain_check = self.state.get_last_message_hash(speaker, peer.peer_id) elif command == BROADCAST: - self_chain_check = self.server.state.get_last_message_hash(speaker) + self_chain_check = self.state.get_last_message_hash(speaker) self_chain_valid = (self_chain_check == self_chain) # log this message for use in the self_chain check - self.server.state.log(speaker, message_bytes, peer.peer_id if (command == DIRECT) else None) + self.state.log(speaker, message_bytes, peer if (command == DIRECT) else None) - # remove padding from message bytes - - for index, byte in enumerate(message): - if binascii.hexlify(byte) == "00": - unpadded_message = message[0:index] - break + # build message object - return Message({ + message = Message({ "peer": peer, - "body": unpadded_message.rstrip(), + "body": body.rstrip(), "timestamp": int_ts, "command": command, "speaker": speaker, @@ -183,12 +225,19 @@ "self_chain": self_chain, "net_chain": net_chain, "self_chain_valid": self_chain_valid, - "error_code": None - }, - self.server) + "message_hash": message_hash + }) + + # check for duplicates + + if(self.state.is_duplicate_message(message_hash)): + message.error_code = DUPLICATE_PACKET + return message + + return message def _pad(self, text, size): - return text.ljust(size) + return text.ljust(size, "\x00") def _ts_range(self): current_ts = int(time.time()) diff -uNr a/blatta/lib/message.py b/blatta/lib/message.py --- a/blatta/lib/message.py 0096d80e9d0c52787f1ad8c43d6b392c5e5434dfd04f62f41361365d858ea8f92b92f901ce4165e2b4fb4079b7aa9e857cf4296dee3eb48759cf63120f3975c5 +++ b/blatta/lib/message.py 0cf5cd14c7e157cf47cf2e4f0f9fd076e89a817a71b389aa8748a8472437e62a2ac35fa5f3b3ae9be39a2f15a926e5920571267bdf39b9ed6b4e0e4cbf9c2970 @@ -1,7 +1,7 @@ class Message(object): - def __init__(self, message, server=None): + def __init__(self, message): self.original = True - self.server = server + self.prefix = None self.handle = message.get("handle") self.peer = message.get("peer") self.body = message.get("body") @@ -13,5 +13,4 @@ self.net_chain = message.get("net_chain") self.self_chain_valid = message.get("self_chain_valid") self.error_code = message.get("error_code") - if server: - self.state = server.state + self.message_hash = message.get("message_hash") diff -uNr a/blatta/lib/peer.py b/blatta/lib/peer.py --- a/blatta/lib/peer.py e763bb836eba69aedebd4d4adfdd8820e1a173f16e1fd493ddd16bf6d41718155c20ee38073570d2aeb0144ea281b6fbcc5ddde48d5ca60cb01a0b08103dd1f5 +++ b/blatta/lib/peer.py c96da174ae6ceb0489ed2b50872c02e481fa7d14d09a747b387bcf7ec45f1ec1dac5326a3370951f20c8995c85fcc7a8b91bc333cedaf78b11345a3897c8f95d @@ -1,20 +1,21 @@ import socket -from infosec import Infosec from commands import IGNORE +from commands import DIRECT +from commands import BROADCAST + import sys import binascii import traceback +import logging class Peer(object): - def __init__(self, server, peer_entry): + def __init__(self, socket, peer_entry): self.handles = peer_entry["handles"] self.keys = peer_entry["keys"] self.peer_id = peer_entry["peer_id"] - self.server = server self.address = peer_entry["address"] self.port = peer_entry["port"] - self.socket = self.server.udp_server_socket - self.infosec = Infosec(server) + self.socket = socket def get_key(self): if len(self.keys) > 0: @@ -22,16 +23,16 @@ else: return None - def send(self, msg): - try: - if msg.command != IGNORE: - self.server.print_debug("packing message: %s" % msg.body) - signed_packet_bytes = self.infosec.pack(self, msg) - self.socket.sendto(signed_packet_bytes, (self.address, self.port)) - self.server.print_debug("[%s:%d] <- %s" % (self.address, - self.port, - binascii.hexlify(signed_packet_bytes)[0:16])) + def send(self, signed_packet_bytes): + if self.get_key() != None: + try: + self.socket.sendto(signed_packet_bytes, (self.address, self.port)) + logging.debug("[%s:%d] <- %s" % (self.address, + self.port, + binascii.hexlify(signed_packet_bytes)[0:16])) - except Exception as ex: - stack = traceback.format_exc() - print(stack) + except Exception as ex: + stack = traceback.format_exc() + logging.debug(stack) + else: + logging.debug("Discarding message to unknown handle or handle with no key: %s" % message.handle) diff -uNr a/blatta/lib/server.py b/blatta/lib/server.py --- a/blatta/lib/server.py 16e7971b6eab7483a4060d5cae5111dec2f61618a2022620343ef7aa3fcedee87cc6499c9f9978215c315fde958e70fa7810f50967e97dd299cd98842118c12d +++ b/blatta/lib/server.py 7f7198c51eb6b00321c1754f1675d907263bf600b8ef67b79641e4a763357fe73935b9f1534234d226f9564e2341aeaafb50bfcff5128bfe14404651b8a36ef2 @@ -1,4 +1,4 @@ -VERSION = "9988" +VERSION = "9987" import os import select @@ -8,29 +8,18 @@ import tempfile import time import string -import binascii -import hashlib import datetime +import sqlite3 from datetime import datetime +from funcs import * from lib.client import Client -from lib.state import State from lib.channel import Channel -from lib.infosec import PACKET_SIZE -from lib.infosec import MAX_BOUNCES -from lib.infosec import STALE_PACKET -from lib.infosec import DUPLICATE_PACKET -from lib.infosec import MALFORMED_PACKET -from lib.infosec import INVALID_SIGNATURE -from lib.infosec import IGNORED -from lib.infosec import Infosec -from lib.peer import Peer +from lib.station import Station from lib.message import Message -from funcs import * -from commands import BROADCAST -from commands import DIRECT -from commands import IGNORE +from lib.infosec import PACKET_SIZE import imp import pprint +import logging class Server(object): def __init__(self, options): @@ -40,18 +29,14 @@ self.password = options.password self.motdfile = options.motd self.verbose = options.verbose - self.debug = options.debug self.logdir = options.logdir self.chroot = options.chroot self.setuid = options.setuid self.statedir = options.statedir - self.infosec = Infosec(self) self.config_file_path = options.config_file_path - self.state = State(self, options.db_path) self.pp = pprint.PrettyPrinter(indent=4) - - if options.address_table_path != None: - self.state.import_at_and_wot(options.address_table_path) + self.db_path = options.db_path + self.address_table_path = options.address_table_path if options.listen: self.address = socket.gethostbyname(options.listen) @@ -61,8 +46,9 @@ self.name = socket.getfqdn(self.address)[:server_name_limit] self.channels = {} # irc_lower(Channel name) --> Channel instance. - self.clients = {} # Socket --> Client instance..peers = "" + self.client = None self.nicknames = {} # irc_lower(Nickname) --> Client instance. + if self.logdir: create_directory(self.logdir) if self.statedir: @@ -79,7 +65,7 @@ try: pid = os.fork() if pid > 0: - self.print_info("PID: %d" % pid) + logging.info("PID: %d" % pid) sys.exit(0) except OSError: sys.exit(1) @@ -113,19 +99,6 @@ else: return [] - def print_info(self, msg): - if self.verbose: - print(msg) - sys.stdout.flush() - - def print_debug(self, msg): - if self.debug: - print("%s %s" % (datetime.now(), msg)) - sys.stdout.flush() - - def print_error(self, msg): - sys.stderr.write("%s\n" % msg) - def client_changed_nickname(self, client, oldnickname): if oldnickname: del self.nicknames[irc_lower(oldnickname)] @@ -139,118 +112,26 @@ def remove_client(self, client, quitmsg): client.message_related(":%s QUIT :%s" % (client.prefix, quitmsg)) for x in client.channels.values(): - client.channel_log(x, "quit (%s)" % quitmsg, meta=True) x.remove_client(client) if client.nickname \ and irc_lower(client.nickname) in self.nicknames: del self.nicknames[irc_lower(client.nickname)] - del self.clients[client.socket] + self.client = None def remove_channel(self, channel): del self.channels[irc_lower(channel.name)] - def handle_udp_data(self, bytes_address_pair): - data = bytes_address_pair[0] - address = bytes_address_pair[1] - packet_info = (address[0], - address[1], - binascii.hexlify(data)[0:16]) - self.print_debug("[%s:%d] -> %s" % packet_info) - for peer in self.state.get_peers(): - if peer.get_key() != None: - message = self.infosec.unpack(peer, data) - error_code = message.error_code - if(error_code == None): - self.print_debug("[%s] -> %s" % (peer.handles[0], message.body)) - - self.conditionally_update_address_table(peer, message, address) - # send the message to all clients - for c in self.clients: - if (self.clients[c].is_addressed_to_me(message.body)): - self.clients[c].message(message.body) - # send the message to all other peers if it should be propagated - if(message.command == BROADCAST) and message.bounces < MAX_BOUNCES: - self.rebroadcast(peer, message) - return - elif error_code == STALE_PACKET: - self.print_debug("[%s:%d] -> stale packet: %s" % packet_info) - return - elif error_code == DUPLICATE_PACKET: - self.print_debug("[%s:%d] -> duplicate packet: %s" % packet_info) - return - elif error_code == MALFORMED_PACKET: - self.print_debug("[%s:%d] -> malformed packet: %s" % packet_info) - return - elif error_code == IGNORED: - self.conditionally_update_address_table(peer, message, address) - self.print_debug("[%s:%d] -> ignoring packet: %s" % packet_info) - return - elif error_code == INVALID_SIGNATURE: - pass - self.print_debug("[%s:%d] -> martian packet: %s" % packet_info) - - # we only update the address table if the speaker is same as peer - - def conditionally_update_address_table(self, peer, message, address): - try: - idx = peer.handles.index(message.speaker) - except: - idx = None - - if idx != None: - self.state.update_address_table({"handle": message.speaker, - "address": address[0], - "port": address[1] - }) - def peer_message(self, message): - message.original = True - if message.command == DIRECT: - peer = self.state.get_peer_by_handle(message.handle) - message_bytes = self.infosec.get_message_bytes(message, peer) - message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest()) - self.state.add_to_dedup_queue(message_hash) - - self.state.log(message.speaker, message_bytes, peer.peer_id) - if peer and (peer.get_key() != None): - peer.send(message) - else: - self.print_debug("Discarding message to unknown handle or handle with no key: %s" % message.handle) - else: - message.timestamp = int(time.time()) - message_bytes = self.infosec.get_message_bytes(message) - if message.command != IGNORE: - self.state.log(message.speaker, message_bytes) - message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest()) - self.state.add_to_dedup_queue(message_hash) - for peer in self.state.get_peers(): - if peer.get_key() != None: - peer.send(message) - else: - self.print_debug("Discarding message to handle with no key: %s" % message.handle) - - def rebroadcast(self, source_peer, message): - message.original = False - for peer in self.state.get_peers(): - if(peer.peer_id != source_peer.peer_id): - message.command = BROADCAST - message.bounces = message.bounces + 1 - peer.send(message) - - - def sendrubbish(self): - for socket in self.clients: - self.peer_message(Message({ - "speaker": self.clients[socket].nickname, - "command": IGNORE, - "bounces": 0, - "body": self.infosec.gen_rubbish_body() - }, self)) - def start(self): # Setup UDP first self.udp_server_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) self.udp_server_socket.bind((self.address, self.udp_port)) - self.print_info("Listening for Pest packets on udp port %d." % self.udp_port) + self.station = Station({ "socket": self.udp_server_socket, + "db_path": self.db_path, + "address_table_path": self.address_table_path + }) + self.station.start_embargo_queue_checking() + self.station.start_rubbish() + logging.info("Listening for Pest packets on udp port %d." % self.udp_port) serversockets = [] for port in self.irc_ports: @@ -259,51 +140,55 @@ try: s.bind((self.address, port)) except socket.error as e: - self.print_error("Could not bind port %s: %s." % (port, e)) + logging.error("Could not bind port %s: %s." % (port, e)) sys.exit(1) s.listen(5) serversockets.append(s) del s - self.print_info("Listening for IRC connections on port %d." % port) + logging.info("Listening for IRC connections on port %d." % port) if self.chroot: os.chdir(self.chroot) os.chroot(self.chroot) - self.print_info("Changed root directory to %s" % self.chroot) + logging.info("Changed root directory to %s" % self.chroot) if self.setuid: os.setgid(self.setuid[1]) os.setuid(self.setuid[0]) - self.print_info("Setting uid:gid to %s:%s" + logging.info("Setting uid:gid to %s:%s" % (self.setuid[0], self.setuid[1])) last_aliveness_check = time.time() while True: (inputready,outputready,exceptready) = select.select([self.udp_server_socket],[],[],0) (iwtd, owtd, ewtd) = select.select( - serversockets + [x.socket for x in self.clients.values()], - [x.socket for x in self.clients.values() - if x.write_queue_size() > 0], + serversockets + ([self.client.socket] if self.client else []), + [self.client.socket] if self.client and self.client.write_queue_size() > 0 else [], [], .2) for x in inputready: - if x == self.udp_server_socket: - bytes_address_pair = self.udp_server_socket.recvfrom(PACKET_SIZE) - self.handle_udp_data(bytes_address_pair) + if x == self.udp_server_socket: + bytes_address_pair = self.udp_server_socket.recvfrom(PACKET_SIZE) + self.station.embargo_queue_lock.acquire() + try: + self.station.handle_udp_data(bytes_address_pair) + except sqlite3.ProgrammingError as ex: + logging.error("sqlite3 concurrency problem") + self.station.embargo_queue_lock.release() for x in iwtd: - if x in self.clients: - self.clients[x].socket_readable_notification() + if self.client != None: + self.client.socket_readable_notification() else: (conn, addr) = x.accept() - self.clients[conn] = Client(self, conn) - self.print_info("Accepted connection from %s:%s." % ( + self.client = Client(self, conn) + self.station.client = self.client + logging.info("Accepted connection from %s:%s." % ( addr[0], addr[1])) for x in owtd: - if x in self.clients: # client may have been disconnected - self.clients[x].socket_writable_notification() + if self.client and x == self.client.socket: # client may have been disconnected + self.client.socket_writable_notification() now = time.time() if last_aliveness_check + 10 < now: - for client in self.clients.values(): - client.check_aliveness() - last_aliveness_check = now - self.sendrubbish() # Kludge to keep ephemeral port open when NATed + if self.client: + self.client.check_aliveness() + last_aliveness_check = now def create_directory(path): if not os.path.isdir(path): diff -uNr a/blatta/lib/state.py b/blatta/lib/state.py --- a/blatta/lib/state.py acd5eaffdba356d5b2b2e0ce494e3be8aed35ccf0b96f9605bfd73fd3f758286f1908d043274a5480ac02c2d270550e1b061b32c0856e521a4eaba2f9f6b29f3 +++ b/blatta/lib/state.py 4f78202d4744a3284c00c4aac9c055f4abae95eea1c51c4acd519a9723a990d4d1fc336254140f75b7d75995a781935a2a6250dd19c7e7610b6643365a47938f @@ -2,42 +2,53 @@ import sqlite3 import imp import hashlib +import logging from itertools import chain class State(object): - - def __init__(self, server, db_path): - self.server = server - self.conn = sqlite3.connect(db_path) - self.cursor = self.conn.cursor() - self.cursor.execute("create table if not exists at(handle_id integer,\ - address text not null,\ - port integer not null,\ - active_at datetime default null,\ - updated_at datetime default current_timestamp,\ - unique(handle_id, address, port))") - - self.cursor.execute("create table if not exists wot(peer_id integer primary key)") - - self.cursor.execute("create table if not exists handles(handle_id integer primary key,\ - peer_id integer,\ - handle text,\ - unique(handle))") - - self.cursor.execute("create table if not exists keys(peer_id intenger,\ - key text,\ - used_at datetime default current_timestamp,\ - unique(key))") - - self.cursor.execute("create table if not exists logs(\ - handle text not null,\ - peer_id integer,\ - message_bytes blob not null,\ - created_at datetime default current_timestamp)") - - self.cursor.execute("create table if not exists dedup_queue(\ - hash text not null,\ - created_at datetime default current_timestamp)") + __instance = None + @staticmethod + def get_instance(socket=None, db_path=None): + if State.__instance == None: + State(socket, db_path) + return State.__instance + + def __init__(self, socket, db_path): + if State.__instance != None: + raise Exception("This class is a singleton") + else: + self.socket = socket + self.conn = sqlite3.connect(db_path, check_same_thread=False) + self.cursor = self.conn.cursor() + self.cursor.execute("create table if not exists at(handle_id integer,\ + address text not null,\ + port integer not null,\ + active_at datetime default null,\ + updated_at datetime default current_timestamp,\ + unique(handle_id, address, port))") + + self.cursor.execute("create table if not exists wot(peer_id integer primary key)") + + self.cursor.execute("create table if not exists handles(handle_id integer primary key,\ + peer_id integer,\ + handle text,\ + unique(handle))") + + self.cursor.execute("create table if not exists keys(peer_id intenger,\ + key text,\ + used_at datetime default current_timestamp,\ + unique(key))") + + self.cursor.execute("create table if not exists logs(\ + handle text not null,\ + peer_id integer,\ + message_bytes blob not null,\ + created_at datetime default current_timestamp)") + + self.cursor.execute("create table if not exists dedup_queue(\ + hash text not null,\ + created_at datetime default current_timestamp)") + State.__instance = self def get_at(self, handle=None): at = [] @@ -60,7 +71,7 @@ (handle_id,)).fetchone()[0] at.append({"handle": h, "address": "%s:%s" % (address, port), - "active_at": updated_at}) + "active_at": updated_at if updated_at else "no packets received from this address"}) return at @@ -69,6 +80,7 @@ self.conn.commit() result = self.cursor.execute("select hash from dedup_queue where hash=?", (message_hash,)).fetchone() + logging.debug("checking if %s is dupe" % message_hash) if(result != None): return True else: @@ -78,6 +90,7 @@ self.cursor.execute("insert into dedup_queue(hash)\ values(?)", (message_hash,)) + logging.debug("added %s to dedup" % message_hash) self.conn.commit() def get_last_message_hash(self, handle, peer_id=None): @@ -96,9 +109,14 @@ if message_bytes: return hashlib.sha256(message_bytes[0][:]).digest() else: - return "0" * 32 + return "\x00" * 32 + + def log(self, handle, message_bytes, peer=None): + if peer != None: + peer_id = peer.peer_id + else: + peer_id = None - def log(self, handle, message_bytes, peer_id=None): self.cursor.execute("insert into logs(handle, peer_id, message_bytes)\ values(?, ?, ?)", (handle, peer_id, buffer(message_bytes))) @@ -124,7 +142,7 @@ self.conn.commit() - def update_address_table(self, peer, set_active_at=True): + def update_at(self, peer, set_active_at=True): row = self.cursor.execute("select handle_id from handles where handle=?", (peer["handle"],)).fetchone() if row != None: @@ -196,7 +214,7 @@ (peer_id,)).fetchall())) def get_peer_handles(self): - handles = list(chain.from_iterable(self.cursor.execute("select handle from handles").fetchall())) + handles = self.listify(self.cursor.execute("select handle from handles").fetchall()) return handles def get_peers(self): @@ -209,6 +227,20 @@ peers.append(peer) return peers + def listify(self, results): + return list(chain.from_iterable(results)) + + def get_keyed_peers(self): + peer_ids = self.listify(self.cursor.execute("select peer_id from keys").fetchall()) + peers = [] + for peer_id in peer_ids: + handle = self.cursor.execute("select handle from handles where peer_id=?", (peer_id,)).fetchone()[0] + peer = self.get_peer_by_handle(handle) + if not (self.is_duplicate(peers, peer)): + peers.append(peer) + return peers + + def get_peer_by_handle(self, handle): handle_info = self.cursor.execute("select handle_id, peer_id from handles where handle=?", (handle,)).fetchone() @@ -219,18 +251,19 @@ address = self.cursor.execute("select address, port from at where handle_id=?\ order by updated_at desc limit 1", (handle_info[0],)).fetchone() - handles = list(chain.from_iterable(self.cursor.execute("select handle from handles where peer_id=?", - (handle_info[1],)).fetchall())) - keys = list(chain.from_iterable(self.cursor.execute("select key from keys where peer_id=?\ + handles = self.listify(self.cursor.execute("select handle from handles where peer_id=?", + (handle_info[1],)).fetchall()) + keys = self.listify(self.cursor.execute("select key from keys where peer_id=?\ order by used_at desc", - (handle_info[1],)).fetchall())) - return Peer(self.server, { + (handle_info[1],)).fetchall()) + return Peer(self.socket, { "handles": handles, "peer_id": handle_info[1], "address": address[0] if address else "", "port": address[1] if address else "", "keys": keys }) + def is_duplicate(self, peers, peer): for existing_peer in peers: if existing_peer.address == peer.address and existing_peer.port == peer.port: diff -uNr a/blatta/lib/station.py b/blatta/lib/station.py --- a/blatta/lib/station.py false +++ b/blatta/lib/station.py 9e41fdd532e857cec8e4d3407560d8570b8e6b7b713739e6d81622b0a6abcbe5a74a9e1ce70f192be2ce5f5c2d0b2374e433bcb7a1d83e322c24460adf92723a @@ -0,0 +1,196 @@ +import time +import threading +import binascii +import logging +import os +from lib.state import State +from lib.infosec import MAX_BOUNCES +from lib.infosec import STALE_PACKET +from lib.infosec import DUPLICATE_PACKET +from lib.infosec import MALFORMED_PACKET +from lib.infosec import INVALID_SIGNATURE +from lib.infosec import IGNORED +from lib.infosec import Infosec +from commands import IGNORE +from lib.message import Message +from commands import BROADCAST +from commands import DIRECT +from lib.peer import Peer + +RUBBISH_INTERVAL = 10 + +class Station(object): + def __init__(self, options): + self.client = None + self.state = State.get_instance(options["socket"], options["db_path"]) + if options.get("address_table_path") != None: + self.state.import_at_and_wot(options.get("address_table_path")) + self.infosec = Infosec(self.state) + self.embargo_queue = {} + self.embargo_queue_lock = threading.Lock() + + def start_embargo_queue_checking(self): + threading.Thread(target=self.check_embargo_queue).start() + + def start_rubbish(self): + pass + threading.Thread(target=self.send_rubbish).start() + + def handle_udp_data(self, bytes_address_pair): + data = bytes_address_pair[0] + address = bytes_address_pair[1] + packet_info = (address[0], + address[1], + binascii.hexlify(data)[0:16]) + logging.debug("[%s:%d] -> %s" % packet_info) + for peer in self.state.get_keyed_peers(): + message = self.infosec.unpack(peer, data) + error_code = message.error_code + if(error_code == None): + logging.debug("%s(%s) -> %s bounces: %d" % (message.speaker, peer.handles[0], message.body, message.bounces)) + self.conditionally_update_at(peer, message, address) + + # if this is a direct message, just deliver it and return + if message.command == DIRECT: + self.deliver(message) + return + + # if the speaker is in our wot, we need to check if the message is hearsay + if message.speaker in self.state.get_peer_handles(): + self.embargo(message) + return + + else: + # skip the embargo and deliver this message with appropriate simple hearsay labeling + message.prefix = "%s[%s]" % (message.speaker, peer.handles[0]) + self.deliver(message) + return + elif error_code == STALE_PACKET: + logging.debug("[%s:%d] -> stale packet: %s" % packet_info) + return + elif error_code == DUPLICATE_PACKET: + logging.debug("[%s:%d] -> duplicate packet: %s" % packet_info) + return + elif error_code == MALFORMED_PACKET: + logging.debug("[%s:%d] -> malformed packet: %s" % packet_info) + return + elif error_code == IGNORED: + self.conditionally_update_at(peer, message, address) + logging.debug("[%s:%d] -> ignoring packet: %s" % packet_info) + return + elif error_code == INVALID_SIGNATURE: + pass + logging.debug("[%s:%d] -> martian packet: %s" % packet_info) + + def deliver(self, message): + # add to duplicate queue + self.state.add_to_dedup_queue(message.message_hash) + + # send to the irc client + if self.client: + self.client.message_from_station(message) + + def embargo(self, message): + # initialize the key/value to empty array if not in the hash + # append message to array + if not message.message_hash in self.embargo_queue.keys(): + self.embargo_queue[message.message_hash] = [] + self.embargo_queue[message.message_hash].append(message) + + def check_embargo_queue(self): + # get a lock so other threads can't mess with the db or the queue + self.embargo_queue_lock.acquire() + self.check_for_immediate_messages() + self.flush_hearsay_messages() + + # release the lock + self.embargo_queue_lock.release() + + # continue the thread loop after interval + time.sleep(1) + threading.Thread(target=self.check_embargo_queue).start() + + def check_for_immediate_messages(self): + for key in dict(self.embargo_queue).keys(): + messages = self.embargo_queue[key] + + for message in messages: + + # if this is an immediate copy of the message + + if message.speaker in message.peer.handles: + + # clear the queue and deliver + + self.embargo_queue.pop(key, None) + self.deliver(message) + self.rebroadcast(message) + break + + + def flush_hearsay_messages(self): + # if we made it this far either we haven't found any immediate messages + # or we sent them all so we must deliver the remaining hearsay messages + # with the appropriate labeling + for key in dict(self.embargo_queue).keys(): + + # collect the source handles + handles = [] + messages = self.embargo_queue[key] + for message in messages: + handles.append(message.peer.handles[0]) + + # select the message with the lowest bounce count + message = sorted(messages, key=lambda m: m.bounces)[0] + + # clear the queue + self.embargo_queue.pop(key, None) + + # compute prefix + if len(messages) < 4: + message.prefix = "%s[%s]" % (message.speaker, "|".join(handles)) + else: + message.prefix = "%s[%d]" % (message.speaker, len(messages)) + + # deliver + self.deliver(message) + + # send the message to all other peers if it should be propagated + self.rebroadcast(message) + + + # we only update the address table if the speaker is same as peer + + def conditionally_update_at(self, peer, message, address): + if message.speaker in peer.handles: + self.state.update_at({ + "handle": message.speaker, + "address": address[0], + "port": address[1] + }) + + def rebroadcast(self, message): + if message.bounces < MAX_BOUNCES: + message.command = BROADCAST + message.bounces = message.bounces + 1 + self.infosec.message(message) + else: + logging.debug("[%s:%d] -> packet TTL expired: %s" % packet_info) + + + def send_rubbish(self): + logging.debug("sending rubbish...") + self.embargo_queue_lock.acquire() + try: + if self.client: + self.infosec.message(Message({ + "speaker": self.client.nickname, + "command": IGNORE, + "bounces": 0, + "body": self.infosec.gen_rubbish_body() + })) + except: + logging.error("Something went wrong attempting to send rubbish") + self.embargo_queue_lock.release() + time.sleep(RUBBISH_INTERVAL) + threading.Thread(target=self.send_rubbish).start() diff -uNr a/blatta/start_test_net.sh b/blatta/start_test_net.sh --- a/blatta/start_test_net.sh 10233fa2a74d0f92f3215b417140a9481f1263ceb7ca4486cca97d48e9c112a36a9b66cb4f2c99a553626dea431d6d8ae6d22735bd2535b8bde7ea964a1f0b21 +++ b/blatta/start_test_net.sh 24a5c19318989da9f79790107499e2ebda16bc5389b739e4e3ae686c3ff024317517203b9c5c3324ae1a391a63f94939e22c8de730e758ecbc6afee4f54e108d @@ -1,6 +1,6 @@ #!/bin/bash # start 3 servers on different ports -./blatta --debug --channel-name \#aleth --irc-port 6668 --udp-port 7778 --db-path a.db --address-table-path test_net_configs/a.py > logs/a & +./blatta --debug --channel-name \#aleth --irc-port 9968 --udp-port 7778 --db-path a.db --address-table-path test_net_configs/a.py > logs/a & ./blatta --debug --channel-name \#aleth --irc-port 6669 --udp-port 7779 --db-path b.db --address-table-path test_net_configs/b.py > logs/b & ./blatta --debug --channel-name \#aleth --irc-port 6670 --udp-port 7780 --db-path c.db --address-table-path test_net_configs/c.py > logs/c & diff -uNr a/blatta/test_net_configs/a.py b/blatta/test_net_configs/a.py --- a/blatta/test_net_configs/a.py 3276661a7529957fb3d7aac616f26be8de21d436ae5092c40662bc2fca472a6be7e460a2d8c76286a8d84bac1e8a8bf94b98e086c9df5a8a9389ff8c9efec8b9 +++ b/blatta/test_net_configs/a.py 27bfacb1a2f3d5c0c9947045e0dbf61d2822c0da84be7b9589b261d79fd3b9b2a845d354fc0310aae2841d5dc0b1d8ff22960d33d779dfbc6e0680bd33424d27 @@ -4,9 +4,9 @@ 'name': 'awt_b', 'port': 7779 }, - { 'address': 'localhost', - 'key': 'lT8/fYe/rQdReyavsTrVqInnLFCaU38o2ZAn5+r8uoFSSWgJelafikFELR9t6SJHMpFQvLmlAbF14nL2PfOAyA==', - 'name': 'awt_c', - 'port': 7780 - } +# { 'address': 'localhost', +# 'key': 'lT8/fYe/rQdReyavsTrVqInnLFCaU38o2ZAn5+r8uoFSSWgJelafikFELR9t6SJHMpFQvLmlAbF14nL2PfOAyA==', +# 'name': 'awt_c', +# 'port': 7780 +# } ] diff -uNr a/blatta/tests/__init__.py b/blatta/tests/__init__.py --- a/blatta/tests/__init__.py false +++ b/blatta/tests/__init__.py 85df4eea67226c8976c9484f97e06ee93506c5e43982babfe99ae2a075f4e1f43f99442cc80897e3ad2d8a409ae59320410347ca74c4317674b815082de8b240 @@ -0,0 +1 @@ +# This file can't be empty otherwise diff won't see it. diff -uNr a/blatta/tests/test_station.py b/blatta/tests/test_station.py --- a/blatta/tests/test_station.py false +++ b/blatta/tests/test_station.py 991e1ff9817a01d4320abdf30e09c890b5454c726e313a488dd6bedc9e8e663019a63caf9fc16979251f653beda05ececa4a806a7e3c299e6847a8b6fe11a6e4 @@ -0,0 +1,194 @@ +# https://stackoverflow.com/questions/1896918/running-unittest-with-typical-test-directory-structure +import unittest +import logging +from mock import Mock +from mock import patch + +from lib.station import Station + +class TestStation(unittest.TestCase): + def setUp(self): + logging.basicConfig(level=logging.DEBUG) + options = { + "clients": {"clientsocket": Mock()}, + "db_path": "tests/test.db", + "socket": Mock() + } + self.station = Station(options) + self.station.deliver = Mock() + self.station.rebroadcast = Mock() + self.station.rebroadcast.return_value = "foobar" + + def tearDown(self): + pass + + def test_embargo_bounce_ordering(self): + peer1 = Mock() + peer1.handles = ["a", "b"] + peer2 = Mock() + peer2.handles = ["c", "d"] + low_bounce_message = Mock() + low_bounce_message.peer = peer1 + low_bounce_message.bounces = 1 + low_bounce_message.message_hash = "messagehash" + high_bounce_message = Mock() + high_bounce_message.peer = peer2 + high_bounce_message.bounces = 2 + high_bounce_message.message_hash = "messagehash" + self.station.embargo_queue = { + "messagehash": [ + low_bounce_message, + high_bounce_message + ], + } + self.station.flush_hearsay_messages() + self.station.deliver.assert_called_once_with(low_bounce_message) + self.station.rebroadcast.assert_called_once_with(low_bounce_message) + + def test_immediate_message_delivered(self): + peer = Mock() + peer.handles = ["a", "b"] + message = Mock() + message.speaker = "a" + message.peer = peer + self.station.embargo_queue = { + "messagehash": [ + message + ], + } + self.station.check_for_immediate_messages() + self.station.deliver.assert_called_once_with(message) + self.station.rebroadcast.assert_called_once_with(message) + + def test_hearsay_message_not_delivered(self): + peer = Mock() + peer.handles = ["a", "b"] + message = Mock() + message.speaker = "c" + message.peer = peer + self.station.embargo_queue = { + "messagehash": [ + message + ], + } + self.station.check_for_immediate_messages() + self.station.deliver.assert_not_called() + + def test_embargo_queue_cleared(self): + peer = Mock() + peer.handles = ["a", "b"] + message = Mock() + message.speaker = "c" + message.peer = peer + self.station.embargo_queue = { + "messagehash": [ + message + ], + } + self.assertEqual(len(self.station.embargo_queue), 1) + self.station.flush_hearsay_messages() + self.assertEqual(len(self.station.embargo_queue), 0) + + def test_immediate_prefix(self): + peer = Mock() + peer.handles = ["a", "b"] + message = Mock() + message.speaker = "a" + message.prefix = None + message.peer = peer + self.station.embargo_queue = { + "messagehash": [ + message + ], + } + self.station.check_for_immediate_messages() + self.assertEqual(message.prefix, None) + + def test_simple_hearsay_prefix(self): + peer = Mock() + peer.handles = ["a", "b"] + message = Mock() + message.speaker = "c" + message.prefix = None + message.peer = peer + self.station.embargo_queue = { + "messagehash": [ + message + ], + } + self.station.flush_hearsay_messages() + self.assertEqual(message.prefix, "c[a]") + + def test_in_wot_hearsay_prefix_under_four(self): + peer1 = Mock() + peer1.handles = ["a", "b"] + peer2 = Mock() + peer2.handles = ["d", "e"] + peer3 = Mock() + peer3.handles = ["f", "g"] + message_via_peer1 = Mock() + message_via_peer1.speaker = "c" + message_via_peer1.prefix = None + message_via_peer1.peer = peer1 + message_via_peer1.bounces = 1 + message_via_peer2 = Mock() + message_via_peer2.speaker = "c" + message_via_peer2.prefix = None + message_via_peer2.peer = peer2 + message_via_peer2.bounces = 2 + message_via_peer3 = Mock() + message_via_peer3.speaker = "c" + message_via_peer3.prefix = None + message_via_peer3.peer = peer3 + message_via_peer3.bounces = 1 + self.station.embargo_queue = { + "messagehash": [ + message_via_peer1, + message_via_peer2, + message_via_peer3 + ], + } + self.station.flush_hearsay_messages() + self.station.deliver.assert_called_once_with(message_via_peer1) + self.assertEqual(message_via_peer1.prefix, "c[a|d|f]") + + def test_in_wot_hearsay_prefix_more_than_three(self): + peer1 = Mock() + peer1.handles = ["a", "b"] + peer2 = Mock() + peer2.handles = ["d", "e"] + peer3 = Mock() + peer3.handles = ["f", "g"] + peer4 = Mock() + peer4.handles = ["f", "g"] + message_via_peer1 = Mock() + message_via_peer1.speaker = "c" + message_via_peer1.prefix = None + message_via_peer1.peer = peer1 + message_via_peer1.bounces = 1 + message_via_peer2 = Mock() + message_via_peer2.speaker = "c" + message_via_peer2.prefix = None + message_via_peer2.peer = peer2 + message_via_peer2.bounces = 2 + message_via_peer3 = Mock() + message_via_peer3.speaker = "c" + message_via_peer3.prefix = None + message_via_peer3.peer = peer3 + message_via_peer3.bounces = 1 + message_via_peer4 = Mock() + message_via_peer4.speaker = "c" + message_via_peer4.prefix = None + message_via_peer4.peer = peer4 + message_via_peer4.bounces = 1 + self.station.embargo_queue = { + "messagehash": [ + message_via_peer1, + message_via_peer2, + message_via_peer3, + message_via_peer4 + ], + } + self.station.flush_hearsay_messages() + self.station.deliver.assert_called_once_with(message_via_peer1) + self.assertEqual(message_via_peer1.prefix, "c[4]")