# Written by Bram Cohen
# see LICENSE.txt for license information

from cStringIO import StringIO
from binascii import b2a_hex
from socket import error as socketerror
import time
#I2P
from traceback import print_exc, print_stack
#/I2P

# Protocol string exchanged at the start of every handshake; on the wire it is
# preceded by a single byte holding its length (see Connection.__init__ and
# Connection.read_header_len / read_header).
protocol_name = 'BitTorrent protocol'

def toint(s):
    """Decode the big-endian unsigned integer stored in byte string *s*.

    Used by Connection.read_len for the 4-byte message length prefix.
    An empty *s* decodes to 0.

    ``int`` is used rather than the Python-2-only ``long``: on Python 2 it
    promotes to long automatically when the value needs it, and on Python 3
    it accepts the bytes object that b2a_hex returns.
    """
    if not s:
        # long('', 16) would raise ValueError; zero bytes encode the value 0.
        return 0
    return int(b2a_hex(s), 16)

def tobinary(i):
    """Encode integer *i* as a 4-character big-endian byte string.

    Intended for 0 <= i < 2**32 (message length prefixes, see
    Connection.send_message).  The most significant octet is passed to
    chr() unmasked, so larger or negative values are not silently
    truncated (on Python 2, chr() raises ValueError for them).
    """
    octets = [i >> 24]
    octets.extend([(i >> shift) & 0xFF for shift in (16, 8, 0)])
    return ''.join(map(chr, octets))

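# Wire format: a handshake of <1 byte: len(protocol_name)><protocol_name>
# <8 reserved bytes><20-byte download id><20-byte sender peer id>, followed by
# any number of [<4-byte big-endian length><message>] frames.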

class Connection:
    """One peer connection: handshake and message framing over a raw connection.

    Incoming bytes are parsed by a small state machine driven from
    data_came_in(): it buffers exactly ``next_len`` bytes, then calls
    ``next_func`` with them.  Each read_* step returns either a
    ``(next_len, next_func)`` pair for the following step, or None, which
    makes data_came_in() close the connection.

    Handshake steps, in order: read_header_len (1 byte), read_header
    (len(protocol_name) bytes), read_reserved (8 bytes), read_download_id
    (20 bytes), read_peer_id (20 bytes).  After that read_len (4 bytes) and
    read_message alternate for each length-prefixed message.
    """
    def __init__(self, Encoder, connection, id, is_local, log):
        # Encoder: the owning Encoder instance (this parameter name shadows
        #   the Encoder class inside this method only).
        # connection: raw connection object; this class calls its write(),
        #   close(), get_ip(), is_flushed() (and g3rmz accessors).
        # id: expected peer id, or None if not known yet.
        # is_local: True if we initiated this connection.
        self.log = log
        self.encoder = Encoder
        self.connection = connection
        self.id = id
        self.locally_initiated = is_local
        self.complete = False     # True once read_peer_id accepts the peer
        self.closed = False       # set by sever()
        self.buffer = StringIO()  # bytes gathered toward the current next_len
        self.next_len = 1
        self.next_func = self.read_header_len
        self.time_connected = time.time()
        
        if self.locally_initiated:
            # We dialed out: send our handshake header immediately.  Our own
            # peer id goes out now only if the remote id is already known;
            # otherwise read_peer_id sends it once the remote id arrives.
            connection.write(chr(len(protocol_name)) + protocol_name + 
                (chr(0) * 8) + self.encoder.download_id)
            if self.id is not None:
                connection.write(self.encoder.my_id)

    #g3rmz
    def get_address(self):
        """Return the raw connection's get_address() (g3rmz addition)."""
        return self.connection.get_address()
       
    def get_have_bitfield(self):
        """Return the raw connection's get_have_bitfield() (g3rmz addition)."""
        return self.connection.get_have_bitfield()
    #end g3rmz
    
    def get_ip(self):
        """Return the raw connection's get_ip()."""
        return self.connection.get_ip()

    def get_id(self):
        """Return the peer id (may be None before the handshake supplies it)."""
        return self.id

    def is_locally_initiated(self):
        """Return True if we opened this connection."""
        return self.locally_initiated

    def is_flushed(self):
        """Return the raw connection's is_flushed()."""
        return self.connection.is_flushed()

    def read_header_len(self, s):
        """Handshake step 1: *s* is one byte that must equal len(protocol_name)."""
        if ord(s) != len(protocol_name):
            return None
        return len(protocol_name), self.read_header

    def read_header(self, s):
        """Handshake step 2: *s* must be exactly protocol_name."""
        if s != protocol_name:
            return None
        return 8, self.read_reserved

    def read_reserved(self, s):
        """Handshake step 3: the 8 reserved bytes; their content is ignored."""
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Handshake step 4: *s* (20 bytes) must equal our download id.

        For incoming connections our full handshake (header, download id and
        our peer id) is sent here, only after the peer's download id has been
        checked.
        """
        if s != self.encoder.download_id:
            return None
        if not self.locally_initiated:
            self.connection.write(chr(len(protocol_name)) + protocol_name + 
                (chr(0) * 8) + self.encoder.download_id + self.encoder.my_id)
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """Handshake step 5: validate the peer's 20-byte id *s*.

        If no id was expected, reject our own id or an id already used by
        another tracked connection, then record it.  On outgoing connections
        our peer id is sent now (deferred from __init__); on incoming ones
        encoder.everinc is set.  If an id was expected, *s* must match it.
        On success the connection is marked complete and the connecter is
        notified via connection_made().
        """
        # NOTE(review): __init__ checks `self.id is not None` but this checks
        # truthiness, so an empty-string id would be treated as unknown here
        # even though __init__ already sent my_id.
        if not self.id:
            if s == self.encoder.my_id:
                return None
            for v in self.encoder.connections.values():
                if s and v.id == s:
                    return None
            self.id = s
            if self.locally_initiated:
                self.connection.write(self.encoder.my_id)
            else:
                self.encoder.everinc = True
        else:
            if s != self.id:
                return None
        self.complete = True
        self.encoder.connecter.connection_made(self)
        return 4, self.read_len

    def read_len(self, s):
        """Read a 4-byte big-endian length prefix; over encoder.max_len drops the peer."""
        l = toint(s)
        if l > self.encoder.max_len:
            return None
        return l, self.read_message

    def read_message(self, s):
        """Deliver one framed message to the connecter.

        A zero-length message (what Encoder.send_keepalives sends) is not
        forwarded.
        """
        if s != '':
            self.encoder.connecter.got_message(self, s)
        return 4, self.read_len

    def read_dead(self, s):
        """Terminal state installed after a step raised: any data closes us."""
        return None

    def close(self):
        """Close the raw connection and unregister; no-op if already closed."""
        if not self.closed:
            self.connection.close()
            self.sever()

    def sever(self):
        """Mark closed, unregister from the encoder, notify if handshaken.

        Raises KeyError if self.connection is no longer a key of
        encoder.connections.
        """
        self.closed = True
        del self.encoder.connections[self.connection]
        if self.complete:
            self.encoder.connecter.connection_lost(self)

    def send_message(self, message):
        """Send *message* framed with a 4-byte big-endian length (tobinary)."""
        self.connection.write(tobinary(len(message)) + message)

    def data_came_in(self, s):
        """Feed raw bytes *s* through the parsing state machine.

        Bytes accumulate in self.buffer until next_len bytes are available,
        then next_func is called with them.  A None result closes the
        connection.  If a step raises, the state switches to read_dead (so
        any later data closes the connection) and the exception is re-raised.
        """
        while True:
            if self.closed:
                return
            # bytes still needed to complete the current step
            i = self.next_len - self.buffer.tell()
            if i > len(s):
                # not enough yet: keep everything and wait for more data
                self.buffer.write(s)
                return
            self.buffer.write(s[:i])
            s = s[i:]
            m = self.buffer.getvalue()
            # empty the buffer before running the next step
            self.buffer.reset()
            self.buffer.truncate()
            try:
                x = self.next_func(m)
            except:
                self.next_len, self.next_func = 1, self.read_dead
                raise
            if x is None:
                self.close()
                return
            self.next_len, self.next_func = x

class Encoder:
    """Track the peer Connections of one download and route raw-server events.

    ``connections`` maps each raw connection object to its Connection
    wrapper.  Outgoing connection requests made while ``max_initiate``
    connections are already tracked are remembered in ``spares`` and started
    from connection_lost() as slots free up.
    """

    def __init__(self, log, connecter, raw_server, my_id, max_len,
            schedulefunc, keepalive_delay, download_id, 
            max_initiate = 40):
        self.log = log
        self.raw_server = raw_server
        self.connecter = connecter
        self.my_id = my_id
        # largest message length accepted from a peer (checked in Connection.read_len)
        self.max_len = max_len
        self.schedulefunc = schedulefunc
        self.keepalive_delay = keepalive_delay
        self.download_id = download_id
        self.max_initiate = max_initiate
        # set by Connection.read_peer_id when an incoming peer passes the handshake
        self.everinc = False
        self.connections = {}   # raw connection -> Connection
        self.spares = []        # dns entries waiting for a free connection slot
        # schedulefunc(func, delay) presumably runs func after delay; it is
        # used to start the periodic keepalive cycle.
        schedulefunc(self.send_keepalives, keepalive_delay)

    def send_keepalives(self):
        """Reschedule itself, then send an empty message on every handshaken connection."""
        self.schedulefunc(self.send_keepalives, self.keepalive_delay)
        for c in self.connections.values():
            if c.complete:
                c.send_message('')

    def start_connection(self, dns, id):
        """Open an outgoing connection to *dns*, expecting peer *id* (may be None).

        Does nothing when dns[0] (minus a trailing '.i2p') equals
        raw_server.mydest, or when *id* is our own id or already connected.
        When max_initiate connections are already tracked, *dns* is queued
        in ``spares`` instead (at most max_initiate entries, no duplicates).
        """
        # I2P: don't connect to ourselves.
        if dns[0][-4:]=='.i2p':
            dest = dns[0][:-4]
        else:
            dest = dns[0]
        if dest == self.raw_server.mydest:
            return
        if id:
            if id == self.my_id:
                return
            for v in self.connections.values():
                if v.id == id:
                    return
        if len(self.connections) >= self.max_initiate:
            if len(self.spares) < self.max_initiate and dns not in self.spares:
                self.spares.append(dns)
            return
        # NOTE: exceptions from raw_server.start_connection propagate to the
        # caller; the former `except socketerror: pass` handler is disabled
        # in this port.
        c = self.raw_server.start_connection(dns)
        self.connections[c] = Connection(self, c, id, True, self.log)
    
    def _start_connection(self, dns, id):
        """Defer start_connection(dns, id) through schedulefunc with zero delay."""
        def foo(self=self, dns=dns, id=id):
            self.start_connection(dns, id)
        
        self.schedulefunc(foo, 0)
        
    def got_id(self, connection):
        """Close *connection* if another tracked Connection has the same id,
        otherwise report it to the connecter via connection_made()."""
        for v in self.connections.values():
            if connection is not v and connection.id == v.id:
                connection.close()
                return
        self.connecter.connection_made(connection)

    def ever_got_incoming(self):
        """Return True once any incoming peer has completed the handshake."""
        return self.everinc

    def external_connection_made(self, connection):
        """Start tracking an incoming raw *connection* (peer id not yet known)."""
        self.connections[connection] = Connection(self, 
            connection, None, False, self.log)

    def connection_flushed(self, connection):
        """Forward a flush event for raw *connection* to the connecter.

        Forwarded only once that peer's handshake is complete.  Events for a
        connection that is no longer tracked are logged at DEBUG (with a
        stack trace) and otherwise ignored (I2P workaround).
        """
        # I2P: FIXME workaround for I2P.  %-formatting (not `+`) so a
        # non-str id such as None cannot raise TypeError here.
        self.log.add('Encrypter', 'connection_flushed on "%s"' % (connection.id,), 'DEBUG')
        if connection in self.connections:
            c = self.connections[connection]
            if c.complete:
                self.connecter.connection_flushed(c)
        else:
            self._log_gone('connection_flushed', connection)

    def connection_lost(self, connection):
        """Unregister raw *connection* and start queued spare connections.

        Untracked connections are logged at DEBUG with a stack trace
        (I2P workaround).
        """
        # I2P: FIXME workaround for I2P
        self.log.add('Encrypter', 'connection_lost on "%s"' % (connection.id,), 'DEBUG')
        if connection in self.connections:
            self.connections[connection].sever()
            # refill free slots from the spares queue (most recent first)
            while len(self.connections) < self.max_initiate and self.spares:
                self.start_connection(self.spares.pop(), None)
        else:
            self._log_gone('connection_lost', connection)
        
    def data_came_in(self, connection, data):
        """Route bytes from raw *connection* to its Connection's parser.

        Data for an untracked connection is dropped and logged at DEBUG
        (I2P workaround).
        """
        # I2P: FIXME workaround for I2P
        if connection in self.connections:
            self.connections[connection].data_came_in(data)
        else:
            self._log_gone('data_came_in', connection, with_stack=False)

    def _log_gone(self, where, connection, with_stack=True):
        """Log at DEBUG that *where* got an event for an untracked *connection*.

        With *with_stack*, also log the current call stack (which includes
        this helper's own frame).
        """
        self.log.add('Encrypter', 'Connection already went away in %s on "%s"'
                     % (where, connection.id), 'DEBUG')
        if with_stack:
            data = StringIO()
            print_stack(file = data)
            self.log.add('Encrypter', 'Data: "%s"' % (data.getvalue(),), 'DEBUG')

