# Written by Bram Cohen
# see LICENSE.txt for license information

from bitfield import Bitfield
from binascii import b2a_hex
from CurrentRateMeasure import Measure

def toint(s):
    """Decode the big-endian byte string *s* into a non-negative integer.

    Raises ValueError if *s* is empty.
    """
    # int() already promotes to long when the value needs it (PEP 237), so an
    # explicit long() adds nothing; int() is also available on Python 3.
    return int(b2a_hex(s), 16)

def tobinary(i):
    """Encode the integer *i* as a 4-byte big-endian string.

    The most significant byte is deliberately left unmasked, so out-of-range
    values are not silently wrapped (Python 2's chr() rejects anything
    outside 0-255).
    """
    high = chr(i >> 24)
    rest = ''.join([chr((i >> shift) & 0xFF) for shift in (16, 8, 0)])
    return high + rest

# Peer-wire message type IDs: the first byte of every message.  The comment
# above each constant lists the payload that follows the type byte.
# CHOKE .. NOT_INTERESTED carry no payload.
CHOKE = chr(0)
UNCHOKE = chr(1)
INTERESTED = chr(2)
NOT_INTERESTED = chr(3)
# index
HAVE = chr(4)
# bitfield
BITFIELD = chr(5)
# index, begin, length
REQUEST = chr(6)
# index, begin, piece
PIECE = chr(7)
# index, begin, length
CANCEL = chr(8)

class Connection:
    """Peer-wire wrapper around a single low-level peer connection.

    Encodes outgoing peer-wire messages onto ``self.connection`` and keeps a
    short human-readable note of the last tracked message sent and received,
    for diagnostics.  The ``upload`` and ``download`` attributes are not set
    here; ``Connecter.connection_made`` attaches them after construction.
    """

    def __init__(self, connection, connecter):
        self.connection = connection    # low-level connection object
        self.connecter = connecter      # owning Connecter (holds rate-cap state)
        self.got_anything = False       # becomes True once any message arrives
        self.have = None                # Bitfield from the peer's BITFIELD message, if any
        self.last_msg_out = ''          # description of last tracked outgoing message
        self.last_msg_in = ''           # description of last tracked incoming message

    #g3rmz
    def get_time_connected(self):
        """Return the underlying connection's ``time_connected`` attribute."""
        return self.connection.time_connected

    def get_last_send(self):
        """Return the description of the last tracked outgoing message."""
        return self.last_msg_out

    def get_last_recv(self):
        """Return the description of the last tracked incoming message."""
        return self.last_msg_in

    def get_address(self):
        """Return the underlying connection's address."""
        return self.connection.get_address()

    def get_have_bitfield(self):
        """Return the Bitfield from the peer's BITFIELD message, or None."""
        # NOTE(review): only an incoming BITFIELD message assigns self.have;
        # HAVE messages do not update it from this class.
        return self.have
    #end g3rmz

    def get_ip(self):
        """Return the underlying connection's IP."""
        return self.connection.get_ip()

    def get_id(self):
        """Return the underlying connection's peer id."""
        return self.connection.get_id()

    def close(self):
        """Close the underlying connection."""
        self.connection.close()

    def is_flushed(self):
        """Report whether more data may be written now.

        Always False while the connecter's upload rate is capped; otherwise
        defers to the underlying connection.
        """
        if self.connecter.rate_capped:
            return False
        return self.connection.is_flushed()

    def is_locally_initiated(self):
        """Return whether the underlying connection was opened by us."""
        return self.connection.is_locally_initiated()

    def send_interested(self):
        """Send an INTERESTED message."""
        self.connection.send_message(INTERESTED)

    def send_not_interested(self):
        """Send a NOT_INTERESTED message."""
        self.connection.send_message(NOT_INTERESTED)

    def send_choke(self):
        """Send a CHOKE message and record it."""
        self.last_msg_out = "choke"
        self.connection.send_message(CHOKE)

    def send_unchoke(self):
        """Send an UNCHOKE message and record it."""
        self.last_msg_out = "unchoke"
        self.connection.send_message(UNCHOKE)

    def send_request(self, index, begin, length):
        """Send a REQUEST for *length* bytes at *begin* within piece *index*."""
        self.connection.send_message(REQUEST + tobinary(index) +
            tobinary(begin) + tobinary(length))

    def send_cancel(self, index, begin, length):
        """Send a CANCEL for a previously requested block and record it."""
        self.last_msg_out = "cancel %d,%d" % (index, begin)
        self.connection.send_message(CANCEL + tobinary(index) +
            tobinary(begin) + tobinary(length))

    def send_piece(self, index, begin, piece):
        """Send a PIECE message carrying *piece* and record it.

        Must not be called while the upload rate is capped (asserted).  The
        piece length is reported to the connecter's upload-rate accounting
        before sending.
        """
        self.last_msg_out = "piece %d,%d" % (index, begin)
        assert not self.connecter.rate_capped
        self.connecter._update_upload_rate(len(piece))
        self.connection.send_message(PIECE + tobinary(index) +
            tobinary(begin) + piece)

    def send_bitfield(self, bitfield):
        """Send a BITFIELD message with the already-encoded *bitfield*."""
        self.connection.send_message(BITFIELD + bitfield)

    def send_have(self, index):
        """Send a HAVE message announcing piece *index*."""
        self.connection.send_message(HAVE + tobinary(index))

    def get_upload(self):
        """Return the upload object attached by the connecter."""
        return self.upload

    def get_download(self):
        """Return the download object attached by the connecter."""
        return self.download

class Connecter:
    """Peer-wire message dispatcher for all peer connections.

    Wraps each raw connection in a ``Connection``, validates and decodes
    incoming peer-wire messages and forwards them to the per-connection
    upload/download objects and the choker, and enforces an optional cap on
    the total upload rate.
    """

    def __init__(self, make_upload, downloader, choker, numpieces,
            totalup, max_upload_rate = 0, sched = None):
        # make_upload(connection) -> upload object for that connection
        self.make_upload = make_upload
        # downloader.make_download(connection) -> download object
        self.downloader = downloader
        # notified via connection_made / connection_lost
        self.choker = choker
        # piece indices >= numpieces are rejected as protocol errors
        self.numpieces = numpieces
        # max_upload_rate <= 0 disables the cap (see _update_upload_rate)
        self.max_upload_rate = max_upload_rate
        # sched(func, delay) runs func later
        self.sched = sched
        # aggregate upload-rate measure
        self.totalup = totalup
        self.rate_capped = False
        # raw connection -> Connection
        self.connections = {}

    def _update_upload_rate(self, amount):
        """Account *amount* uploaded bytes and cap uploads if over the limit.

        When capped, an ``_uncap`` call is scheduled for the moment the total
        rate is expected to fall back to the limit.
        """
        self.totalup.update_rate(amount)
        if self.max_upload_rate > 0 and self.totalup.get_rate_noupdate() > self.max_upload_rate:
            self.rate_capped = True
            self.sched(self._uncap, self.totalup.time_until_rate(self.max_upload_rate))

    def _uncap(self):
        """Lift the upload cap and let stalled uploads resume.

        Repeatedly flushes the unchoked upload with the lowest current rate
        that has pending requests and a flushed connection, which favours
        slow peers.  Stops when no such upload remains, when sending re-caps
        the rate, or when an active cap is exceeded again.
        """
        self.rate_capped = False
        while not self.rate_capped:
            up = None
            minrate = None
            for i in self.connections.values():
                # Check the raw connection; Connection.is_flushed would
                # consult rate_capped as well.
                if not i.upload.is_choked() and i.upload.has_queries() and i.connection.is_flushed():
                    rate = i.upload.get_rate()
                    if up is None or rate < minrate:
                        up = i.upload
                        minrate = rate
            if up is None:
                break
            up.flushed()
            # Only an active cap (max_upload_rate > 0, as in
            # _update_upload_rate) may end the drain early.  Without this
            # guard a disabled cap of 0 would stop after one upload whenever
            # anything had been sent, possibly leaving other uploads idle.
            if (self.max_upload_rate > 0 and
                    self.totalup.get_rate_noupdate() > self.max_upload_rate):
                break

    def change_max_upload_rate(self, newval):
        """Schedule a change of the upload limit to *newval* (<= 0 = no cap)."""
        self.sched(lambda: self._change_max_upload_rate(newval), 0)

    def _change_max_upload_rate(self, newval):
        """Apply a new upload limit."""
        # The cap is not lifted here; an already-scheduled _uncap reads the
        # new limit when it runs.
        self.max_upload_rate = newval

    def how_many_connections(self):
        """Return the number of tracked connections."""
        return len(self.connections)

    def connection_made(self, connection):
        """Register a new raw connection and attach its upload/download."""
        c = Connection(connection, self)
        self.connections[connection] = c
        c.upload = self.make_upload(c)
        c.download = self.downloader.make_download(c)
        self.choker.connection_made(c)

    def connection_lost(self, connection):
        """Forget a raw connection and notify its download and the choker."""
        c = self.connections[connection]
        d = c.download
        del self.connections[connection]
        d.disconnected()
        self.choker.connection_lost(c)

    def connection_flushed(self, connection):
        """Let the connection's upload send more data."""
        self.connections[connection].upload.flushed()

    def got_message(self, connection, message):
        """Validate and dispatch one incoming peer-wire message.

        The first byte is the message type.  Any malformed message (wrong
        length, out-of-range piece index, BITFIELD after another message,
        undecodable bitfield, or unknown type) closes the connection.
        """
        c = self.connections[connection]
        t = message[0]
        # BITFIELD is only accepted as the very first message.
        if t == BITFIELD and c.got_anything:
            connection.close()
            return
        c.got_anything = True
        # The state messages carry no payload.
        if (t in [CHOKE, UNCHOKE, INTERESTED, NOT_INTERESTED] and
                len(message) != 1):
            connection.close()
            return
        if t == CHOKE:
            c.last_msg_in = "choke"
            c.download.got_choke()
        elif t == UNCHOKE:
            c.last_msg_in = "unchoke"
            c.download.got_unchoke()
        elif t == INTERESTED:
            c.last_msg_in = "interested"
            c.upload.got_interested()
        elif t == NOT_INTERESTED:
            c.last_msg_in = "not interested"
            c.upload.got_not_interested()
        elif t == HAVE:
            # type byte + 4-byte index
            if len(message) != 5:
                connection.close()
                return
            i = toint(message[1:])
            if i >= self.numpieces:
                connection.close()
                return
            c.download.got_have(i)
        elif t == BITFIELD:
            try:
                b = Bitfield(self.numpieces, message[1:])
            except ValueError:
                connection.close()
                return
            c.have = b
            c.download.got_have_bitfield(b)
        elif t == REQUEST:
            # type byte + index + begin + length, 4 bytes each
            if len(message) != 13:
                connection.close()
                return
            i = toint(message[1:5])
            if i >= self.numpieces:
                connection.close()
                return
            c.upload.got_request(i, toint(message[5:9]),
                toint(message[9:]))
        elif t == CANCEL:
            # same layout as REQUEST
            if len(message) != 13:
                connection.close()
                return
            i = toint(message[1:5])
            if i >= self.numpieces:
                connection.close()
                return
            c.upload.got_cancel(i, toint(message[5:9]),
                toint(message[9:]))
            c.last_msg_in = "cancel %d" % i
        elif t == PIECE:
            # type byte + index + begin, then at least one byte of data
            if len(message) <= 9:
                connection.close()
                return
            i = toint(message[1:5])
            if i >= self.numpieces:
                connection.close()
                return
            start = toint(message[5:9])
            # A true result from got_piece triggers a HAVE to every tracked
            # connection, including this one.
            if c.download.got_piece(i, start, message[9:], ip=c.get_ip()):
                for co in self.connections.values():
                    co.send_have(i)
            c.last_msg_in = "piece %d,%d" % (i, start)
        else:
            connection.close()

