from bisect import insort
import string
from cStringIO import StringIO
from traceback import print_exc, print_stack
from errno import EINTR, EWOULDBLOCK
from time import time, sleep
from random import randrange
from BitTorrent.SAM import sam_bittorrent

import Logger
import SamBuffer

class SingleSocket:
    """One SAM stream connection with a buffer for outgoing data.

    All network operations are delegated back to ``raw_server``
    (``_sendData`` to transmit, ``_sendClose`` to close).  ``log`` must expose
    ``add(component, message, level)``.
    """

    # Largest chunk handed to raw_server._sendData in a single call.
    SEND_CHUNK = 32768

    def __init__(self, raw_server, sock, handler, log):
        self.raw_server = raw_server
        # sock is an (id, dest) pair
        self.id = sock[0]
        self.dest = sock[1]
        self.handler = handler
        self.log = log
        self.buffer = StringIO()   # outgoing data not yet sent
        self.last_hit = time()     # timestamp of the last activity
        self.connected = False

    def connect(self):
        """Mark the socket as connected so try_write() will send data."""
        self.connected = True

    def close(self):
        """Initiate socket closing."""
        self.raw_server._sendClose(self.id)

    def get_ip(self):
        """Return the remote destination (there is no IP address in SAM)."""
        return self.dest

    def get_address(self):
        """Return the remote destination."""
        return self.dest

    def shutdown(self, val):
        """Close the stream; ``val`` is accepted for API compatibility and ignored."""
        self.raw_server._sendClose(self.id)

    def is_flushed(self):
        """Return True when no outgoing data is pending."""
        return self.buffer.tell() == 0

    def write(self, s):
        """Append ``s`` to the outgoing buffer; it is sent by try_write()."""
        self.log.add('SamServer', 'Writing to buffer of ID "'+self.id+'", length: "'+str(len(s))+'"', 'DEBUG')
        self.buffer.write(s)

    def try_write(self):
        """Send all buffered data, if connected, in chunks of SEND_CHUNK.

        If ``raw_server._sendData`` raises, the chunks already sent are dropped
        from the buffer and only the unsent remainder is kept (positioned so
        later write() calls append after it).  A retry therefore neither
        resends nor overwrites data.  The exception still propagates.
        """
        if not self.connected:
            return
        data = self.buffer.getvalue()
        sent = 0
        try:
            while sent < len(data):
                block = data[sent:sent + self.SEND_CHUNK]
                self.raw_server._sendData(self.id, block)
                sent += len(block)
        finally:
            # Rebuild the buffer holding only the unsent tail.  seek(0) +
            # truncate() works for both cStringIO and io.StringIO.
            self.buffer.seek(0)
            self.buffer.truncate()
            if sent < len(data):
                self.buffer.write(data[sent:])

def default_error_handler(x):
    """Fallback error callback: print the error ``x`` to standard output."""
    print(x)

class SamServer:
    """Event loop that drives stream connections through a SAM bridge object.

    Open sockets are tracked in ``single_sockets`` (SAM id -> SingleSocket)
    and polled with ``SAM.select``.  Timed tasks are queued by add_task() and
    kept in ``funcs`` as a sorted list of (deadline, func) pairs.
    """

    def __init__(self, doneflag, info_hash, sam, log, timeout_check_interval, timeout,
                 noisy = True, errorfunc = default_error_handler,
                 maxconnects = 55):
        """Set up state and schedule the periodic idle-timeout scan.

        doneflag -- object with isSet(); listen_forever() runs until it is set
                    (presumably a threading.Event -- confirm with callers).
        info_hash -- torrent info hash; used to fetch incoming connections.
        sam -- SAM bridge wrapper (status, connect, send, recv, select, ...).
        log -- logger exposing add(component, message, level).
        timeout_check_interval -- seconds between scan_for_timeouts() runs.
        timeout -- seconds without received data before a socket is closed.
        noisy, errorfunc, maxconnects -- stored; not used by this class.
        """
        # default sam config
        self.waittime = 1.0
        self.info_hash = info_hash
        self.SAM = sam
        self.log = log
        self.log.add('SamServer', 'INFO_HASH: "'+info_hash+'"', 'DEBUG')
        self.version = '1.0'
        self.tunnel_depth = '1'
        self.tunnel_number = '1'
        self.length_variance = '1'

        self.connected = self.SAM.status()
        self.handshaked = True
        self.bound = True

        self.samBuffer = SamBuffer.SamBuffer()
        self.outBuffer = []

        self.timeout_check_interval = timeout_check_interval
        self.timeout = timeout
        self.single_sockets = {}
        self.dead_from_write = []
        self.doneflag = doneflag
        self.noisy = noisy
        self.errorfunc = errorfunc
        self.maxconnects = maxconnects
        self.funcs = []
        self.unscheduled_tasks = []

        self.maxId = 0  # number of outgoing connects requested so far
        self.mydest = self.SAM.get_own_dest()
        self.add_task(self.scan_for_timeouts, timeout_check_interval)

    def _sendClose(self, id):
        """Ask SAM to shut down stream ``id`` and log the response."""
        merk = self.SAM.shutdown(id)
        self.log.add('SamServer', 'CLOSE_ID: "'+id+'", RESPONSE: "'+merk+'"', 'INFO')

    def _sendConnect(self, dest):
        """Ask SAM to open a stream to ``dest``; return its response (used as the socket id)."""
        response = self.SAM.connect(dest)
        self.maxId = self.maxId + 1
        self.log.add('SamServer', 'Connect: "'+dest[0:10]+'" response: "'+response+'"', 'INFO')
        return response

    def _sendData(self, id, msg):
        """Send ``msg`` on stream ``id`` and log the response."""
        merk = self.SAM.send(id, msg)
        self.log.add('SamServer', 'SEND: ID '+id+' (LEN: '+str(len(msg))+') RESPONSE: "'+merk+'"', 'INFO')

    def _recv(self, id):
        """Read pending data for stream ``id`` and pass it to its handler.

        SAM error responses are ignored.  Data for an untracked id is
        dropped and logged.
        """
        data = self.SAM.recv(id)
        if data[0] == 'ERROR' or data[0] == 'SAM_ERROR':
            return
        s = self.single_sockets.get(id)
        if s is None:
            self.log.add('SamServer', 'Lost socket? ('+id+')', 'INFO')
            return
        s.last_hit = time()
        self.log.add('SamServer', 'RECV: '+id+' LEN ('+str(len(data[1]))+') Response: "'+data[0]+'"', 'INFO')
        s.handler.data_came_in(s, data[1])

    def add_task(self, func, delay):
        """Queue ``func`` to run about ``delay`` seconds from the next scheduling pass."""
        self.unscheduled_tasks.append((func, delay))

    def scan_for_timeouts(self):
        """Close sockets with no received data within ``timeout``; reschedule itself."""
        self.add_task(self.scan_for_timeouts, self.timeout_check_interval)
        cutoff = time() - self.timeout
        # Collect first, then close, so closing cannot disturb the iteration.
        tokill = [s for s in self.single_sockets.values() if s.last_hit < cutoff]
        for k in tokill:
            self.log.add('SamServer', 'TIMEOUT: "'+k.id+'"', 'INFO')
            k.close()

    def start_connection(self, dns, handler = None):
        """Connect to a remote location. Returns a SingleSocket.

        dns -- (dest, port) pair; the port is ignored.
        handler -- receives the socket's events; defaults to the handler
                   passed to listen_forever().
        """
        if handler is None:
            handler = self.handler
        dest = dns[0]
        # Remove Azureus style .i2p suffix from full destinations
        if len(dest) >= 256 and dest.endswith('.i2p'):
            dest = dest[:-4]
        self.log.add('SamServer', 'LEN: "'+str(len(dest))+'"', 'DEBUG')
        Id = self._sendConnect(dest)
        # Use the resolved handler so a caller-supplied one is honored.
        s = SingleSocket(self, (Id, dest), handler, self.log)
        self.single_sockets[Id] = s
        return s

    def loopOnce(self, timeout=1.0):
        """Do a single poll loop.

        Sleeps for the adaptive ``self.waittime`` (``timeout`` is currently
        ignored), then polls all tracked sockets with SAM.select.  It then
        marks newly writable sockets connected, delivers incoming data, closes
        errored sockets, and accepts new incoming connections.  Finally it
        flushes pending output and adapts ``waittime``.  If SAM.select
        returns an error, none of that happens.
        """
        # Do not uncomment except when debugging: logging on every iteration
        # heavily hogs the CPU.
        #self.log.add('SamServer', 'WAIT: '+str(self.waittime), 'NORMAL')
        sleep(self.waittime)  # Don't overdo this...
        sock_ids = list(self.single_sockets.keys())
        moved_data = 0

        # see http://mail.python.org/pipermail/python-dev/2000-October/009671.html
        result = self.SAM.select(sock_ids, sock_ids, sock_ids)
        if result == 'ERROR' or result == 'SAM_ERROR':
            return
        readable, writable, errored = result[0], result[1], result[2]

        # A socket reported writable for the first time is now connected.
        for sock_id in writable:
            s = self.single_sockets.get(sock_id)
            if s is not None and not s.connected:
                self.log.add('SamServer', 'NEW - OUT: "'+sock_id+'"', 'INFO')
                s.connect()

        # receive incoming data
        for sock_id in readable:
            self._recv(sock_id)
            moved_data = 1

        # Close errored sockets.  An unknown id yields None, which
        # _close_socket only logs.
        for sock_id in errored:
            self.log.add('SamServer', 'CLOSED: "'+sock_id+'"', 'INFO')
            self._close_socket(self.single_sockets.get(sock_id))

        # handle incoming connections
        for conn in self.SAM.get_in_conns(self.info_hash):
            conn_id = conn[0]
            dest = conn[1]
            if conn_id not in self.single_sockets:
                self.log.add('SamServer', 'NEW - IN: "'+conn_id+'"', 'INFO')
                s = SingleSocket(self, (conn_id, dest), self.handler, self.log)
                s.connect()
                self.single_sockets[conn_id] = s
                self.handler.external_connection_made(s)

        # Flush pending output.  Iterate over a snapshot because handler
        # callbacks may change the socket table.
        for s in list(self.single_sockets.values()):
            if s.connected and not s.is_flushed():
                moved_data = 1
                s.try_write()
                if s.is_flushed():
                    s.handler.connection_flushed(s)

        self.calculate_waittime(moved_data)

    def pop_unscheduled(self):
        """Move every queued (func, delay) task into the sorted ``funcs`` list."""
        while self.unscheduled_tasks:
            func, delay = self.unscheduled_tasks.pop()
            # The periodic display task is too chatty to log.
            if not str(func).startswith('<bound method DownloaderFeedback.display'):
                self.log.add('SamServer', 'FUNC: "'+str(func)+'" DELAY: "'+str(delay)+'"', 'DEBUG')
            insort(self.funcs, (time() + delay, func))

    def listen_forever(self, handler):
        """Run the event loop with ``handler`` until ``doneflag`` is set.

        NOTE: exceptions raised by handlers or tasks propagate out of this
        loop, and open sockets are not closed on exit.
        """
        self.handler = handler
        self.connected = self.SAM.status()
        while not self.doneflag.isSet():
            self.pop_unscheduled()
            if self.funcs:
                period = max(0, self.funcs[0][0] - time())
            else:
                period = 2 ** 30

            # network stuff
            self.loopOnce(period)

            # run every task whose deadline has passed
            while self.funcs and self.funcs[0][0] <= time():
                garbage, func = self.funcs.pop(0)
                func()
        self.log.add('SamServer', 'cleanly terminated', 'DEBUG')

    def _close_dead(self):
        """Close every socket queued in ``dead_from_write``, including ones queued meanwhile."""
        while len(self.dead_from_write) > 0:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                self._close_socket(s)

    def _close_socket(self, s):
        """Forget socket ``s``, release it in SAM and notify its handler.

        A falsy ``s`` (e.g. None for an untracked id) is only logged.
        """
        if s:
            sock_id = s.id
            del self.single_sockets[sock_id]
            s.connected = False
            self.SAM.cleanup(sock_id)

            # uncomment to dump the stack to the logs
            #data = StringIO()
            #print_stack(file = data)
            #self.log.add('SamServer', data.getvalue(), 'DEBUG')

            s.handler.connection_lost(s)
        else:
            # There is no id to report for a falsy socket; log its repr.
            self.log.add('SamServer', 'wtf is this, a socket obj is "False"?! "'+repr(s)+'"', 'DEBUG')

    def calculate_waittime(self, moved_data):
        """Adapt the poll sleep to traffic.

        With traffic (moved_data == 1) and waittime above 0.01 s, shrink it by
        10%.  Otherwise, grow it by 10% while it is below 1.0 s.  With traffic
        at the floor it therefore grows, so it oscillates around 0.01 s.
        """
        if moved_data == 1 and self.waittime > 0.01:
            self.waittime -= self.waittime / 10.0
        elif self.waittime < 1.0:
            self.waittime += self.waittime / 10.0