# Written by Bram Cohen
# see LICENSE.txt for license information

import os

try:
    from mhash import MHASH, MHASH_SHA1
    mhash_flag = True
except:
    from sha import sha
    mhash_flag = False

from threading import Event
from bitfield import Bitfield
from bencode import bencode, bdecode
from md5 import md5

try:
    from mhash import MHASH, MHASH_SHA1
    mhash_flag = True
except:
    from sha import sha
    mhash_flag = False

def dummy_status(fractionDone = None, activity = None):
    """Default status callback: accept a progress update and discard it."""
    return None

def dummy_data_flunked(size, index = None, ip = None):
    """Default callback for a downloaded piece that failed its hash check.

    StorageWrapper invokes it as data_flunked(length, index=..., ip=...),
    so the optional keyword arguments must be accepted; this default
    ignores all of them.
    """
    pass

class StorageWrapper:
    def __init__(self, storage, request_size, hashes, 
            piece_size, finished, failed, 
            statusfunc = dummy_status, flag = Event(), check_hashes = True,
            data_flunked = dummy_data_flunked, pre_allocate = True, resumefile=None, errorfunc=None):

        #fast resume version number
        self.fastversion = 5
        # only needs to be changed if previous version compatibility is affected
        
        self.numpieces = len(hashes)
        self.resumefile = resumefile
        self.errorfunc = errorfunc
        self.storage = storage
        self.request_size = request_size
        self.hashes = hashes
        self.piece_size = piece_size
        self.data_flunked = data_flunked
        self.total_length = storage.get_total_length()
        self.amount_left = self.total_length
        if self.total_length <= piece_size * (self.numpieces - 1):
            raise ValueError, _('bad data from tracker - total too small')
        if self.total_length > piece_size * self.numpieces:
            raise ValueError, _('bad data from tracker - total too big')
        self.finished = finished
        self.failed = failed
        self.numactive = [0] * self.numpieces
        self.inactive_requests = [1] * self.numpieces
        self.amount_inactive = self.total_length
        self.endgame = False
        self.have = Bitfield(self.numpieces)
        self.waschecked = [check_hashes] * self.numpieces
        self.places = {}
        self.holes = []
        self.numchecked = 0
        
        if self.numpieces == 0:
            finished()
            return
        targets = {}
        total = self.numpieces
        for i in xrange(self.numpieces):
            if not self._waspre(i):
                targets.setdefault(hashes[i], []).append(i)
                total -= 1
                
        if total and check_hashes:
            statusfunc({"activity" : 'checking existing file', 
                "fractionDone" : 0})
        def markgot(piece, pos, self = self, check_hashes = check_hashes):
            self.places[piece] = pos
            self.have[piece] = True
            self.amount_left -= self._piecelen(piece)
            self.amount_inactive -= self._piecelen(piece)
            self.inactive_requests[piece] = None
            self.waschecked[piece] = check_hashes
        lastlen = self._piecelen(self.numpieces - 1)

        fastresume = False
        datachkOK = False
        rdata = self._LoadFastResume()
        td_data = self.storage.GetDateSizeData()

        try:
            if rdata != None:
                #check version of fast resume
                if rdata[2] == self.fastversion: 
                    #checksum verification
                    if mhash_flag:
                        checksum = MHASH(MHASH_SHA1, ( str(self.fixarray(rdata[0])) )).digest()
                    else:        
                        checksum = sha( str(self.fixarray(rdata[0])) ).digest()
                    if rdata[3] == checksum: 
                        #check that date size info is correct length
                        if len(rdata[1]) == len(td_data): 
                            #check date and size info to make sure files haven't been modified
                            for i in xrange(len(td_data)):
                                if rdata[1][i][0] == td_data[i][0] and rdata[1][i][1] == td_data[i][1]:
                                    datachkOK = True
                                else:
                                    datachkOK = False
                                    break

                        #mark the pieces that the user already has
                        if datachkOK and len(rdata[0]) == len(self.have):
                            for i in xrange(self.numpieces):
                                self.numchecked += 1
                                if rdata[0][i]:
                                    markgot(i, i)
                            fastresume = True
                        #resume data incorrect length
                        elif len(rdata[0]) != len(self.have):
                            self.errorfunc(_("Total resume piece size and total piece size did not match.  Rechecking hashes..."), -2)
                        #size/date info invalid - files have been modified
                        elif not datachkOK:
                            self.errorfunc(_("Date and size info in Fast Resume data does not match files on disk.  Rechecking hashes..."), -2)
                        #default - unknown error
                        else:
                            self.errorfunc(_("An unknown error occurred in fast resume"), -2)
                            self.errorfunc(_("Rechecking hashes..."), -1)
                    #checksum failed
                    else:
                        self.errorfunc(_("md5 checksum failed - resume data was corrupt"))
                        self.errorfunc(_("Rechecking hashes..."), -1)
                #incompatible version 
                else:
                    self.errorfunc(_("Current resume file (Version %s) is not compatibile with this version of Rufus fast resume (Version %s).  Rechecking hashes...") % (rdata[2], self.fastversion), -2)
        #something went very wrong
        except:
            self.errorfunc(_("An unknown exception has occurred in fast resume"), -2)
            self.errorfunc(_("Rechecking hashes..."), -1)

        for i in xrange(self.numpieces):
            if not self._waspre(i):
                self.holes.append(i)
            elif fastresume:
                break
            elif not check_hashes:
                markgot(i, i)
            else:
                    
                if mhash_flag:
                    sh = MHASH(MHASH_SHA1, (self.storage.read(piece_size * i, lastlen)))
                else:
                    sh = sha(self.storage.read(piece_size * i, lastlen))                    
                
                sp = sh.digest()
                sh.update(self.storage.read(piece_size * i + lastlen, self._piecelen(i) - lastlen))
                s = sh.digest()
                if s == hashes[i]:
                    markgot(i, i)
                elif targets.get(s) and self._piecelen(i) == self._piecelen(targets[s][-1]):
                    markgot(targets[s].pop(), i)
                elif not self.have[self.numpieces - 1] and sp == hashes[-1] and (i == self.numpieces - 1 or not self._waspre(self.numpieces - 1)):
                    markgot(self.numpieces - 1, i)
                else:
                    self.places[i] = i
                if flag.isSet():
                    return
            if not fastresume:
                self.numchecked += 1

            statusfunc({'fractionDone': float(i) / self.numpieces})

                
        if self.amount_left == 0:
            finished()

        #save resume data after initial hashcheck
        self.SaveFastResume()
        
        if (pre_allocate and check_hashes) or len(self.storage.files) > 1:
            self.reorder_pieces(flag, statusfunc)


    def reorder_pieces(self, doneflag, statusfunc):
        print 'reordering'
        statusfunc({"activity" : 'reordering', "fractionDone" : 0})
        nmoves = max(1, len(self.places))
        j = 0
        for i in xrange(self.numpieces):
            if doneflag.isSet():
                return

            if not self.places.has_key(i):
                self.places[i] = i
                j += 1

            elif self.places[i] != i:
                piece_dest = self.storage.read(self.piece_size * i, self._piecelen(i))
                piece_src = self.storage.read(self.piece_size * self.places[i], self._piecelen(i))

                self.storage.write(self.piece_size * i, piece_src)
                self.storage.write(self.piece_size * self.places[i], piece_dest)
                self.places[i] = i
                j += 1
            statusfunc({"fractionDone": float(j) / nmoves})

    def _waspre(self, piece):
        return self.storage.was_preallocated(piece * self.piece_size, self._piecelen(piece))

    def _piecelen(self, piece):
        if piece < self.numpieces - 1:
            return self.piece_size
        else:
            return self.total_length - piece * self.piece_size

    def get_amount_left(self):
        return self.amount_left

    def do_I_have_anything(self):
        return self.amount_left < self.total_length

    def _make_inactive(self, index):
        length = min(self.piece_size, self.total_length - self.piece_size * index)
        l = []
        x = 0
        while x + self.request_size < length:
            l.append((x, self.request_size))
            x += self.request_size
        l.append((x, length - x))
        self.inactive_requests[index] = l

    def is_endgame(self):
        return self.endgame

    #g3rmz
    def get_have_array(self): 
        return self.have.array
        
    def get_have_list(self):
        return self.have.tostring()

    def get_request_list(self):
        return self.inactive_requests

    def get_nhashes(self):
        return self.numpieces
    
    def get_files(self):
        return self.storage.files

    def do_I_have(self, index):
        return self.have[index]

    def do_I_have_requests(self, index):
        return not not self.inactive_requests[index]

    def new_request(self, index):
        # returns (begin, length)
        if self.inactive_requests[index] == 1:
            self._make_inactive(index)
        self.numactive[index] += 1
        rs = self.inactive_requests[index]
        r = min(rs)
        rs.remove(r)
        self.amount_inactive -= r[1]
        if self.amount_inactive == 0:
            self.endgame = True
        return r

    def piece_came_in(self, index, begin, piece, ip=None):
        try:
            return self._piece_came_in(index, begin, piece, ip)
        except IOError, e:
            self.failed('IO Error ' + str(e))
            return True

    def _piece_came_in(self, index, begin, piece, ip=None):
        if not self.places.has_key(index):
            n = self.holes.pop(0)
            if self.places.has_key(n):
                oldpos = self.places[n]
                old = self.storage.read(self.piece_size * oldpos, self._piecelen(n))  
                    
                if mhash_flag:
                    if self.have[n] and MHASH(MHASH_SHA1, (old)).digest() != self.hashes[n]:
                        self.failed(_('data corrupted on disk - maybe you have two copies running?'))
                        return True
                    self.storage.write(self.piece_size * n, old)
                    self.places[n] = n
                    if index == oldpos or index in self.holes:
                        self.places[index] = oldpos
                    else:
                        for p, v in self.places.items():
                            if v == index:
                                break
                        self.places[index] = index
                        self.places[p] = oldpos
                        old = self.storage.read(self.piece_size * index, self.piece_size)
                        self.storage.write(self.piece_size * oldpos, old)
                else:                    
                    if self.have[n] and sha(old).digest() != self.hashes[n]:   
                        self.failed(_('data corrupted on disk - maybe you have two copies running?'))
                        return True
                    self.storage.write(self.piece_size * n, old)
                    self.places[n] = n
                    if index == oldpos or index in self.holes:
                        self.places[index] = oldpos
                    else:
                        for p, v in self.places.items():
                            if v == index:
                                break
                        self.places[index] = index
                        self.places[p] = oldpos
                        old = self.storage.read(self.piece_size * index, self.piece_size)
                        self.storage.write(self.piece_size * oldpos, old)
                        
            elif index in self.holes or index == n:
                if not self._waspre(n):
                    self.storage.write(self.piece_size * n, self._piecelen(n) * chr(0xFF))
                self.places[index] = n
            else:
                for p, v in self.places.items():
                    if v == index:
                        break
                self.places[index] = index
                self.places[p] = n
                old = self.storage.read(self.piece_size * index, self._piecelen(n))
                self.storage.write(self.piece_size * n, old)
        self.storage.write(self.places[index] * self.piece_size + begin, piece)
        self.numactive[index] -= 1
        if not self.inactive_requests[index] and not self.numactive[index]:        
             
            if mhash_flag:
                if MHASH(MHASH_SHA1, (self.storage.read(self.piece_size * self.places[index], self._piecelen(index)))).digest() == self.hashes[index]:
                    self.have[index] = True
                    self.inactive_requests[index] = None
                    self.waschecked[index] = True
                    self.amount_left -= self._piecelen(index)
                    if self.amount_left == 0:
                        self.finished()
                else:
                    self.data_flunked(self._piecelen(index), index=index, ip=ip)
                    self.inactive_requests[index] = 1
                    self.amount_inactive += self._piecelen(index)
                    return False 
            else:               
                if sha(self.storage.read(self.piece_size * self.places[index], self._piecelen(index))).digest() == self.hashes[index]:               
                    self.have[index] = True
                    self.inactive_requests[index] = None
                    self.waschecked[index] = True
                    self.amount_left -= self._piecelen(index)
                    if self.amount_left == 0:
                        self.finished()
                else:
                    self.data_flunked(self._piecelen(index), index=index, ip=ip)
                    self.inactive_requests[index] = 1
                    self.amount_inactive += self._piecelen(index)
                    return False
                    
        return True

    def request_lost(self, index, begin, length):
        self.inactive_requests[index].append((begin, length))
        self.amount_inactive += length
        self.numactive[index] -= 1

    def get_piece(self, index, begin, length):
        try:
            return self._get_piece(index, begin, length)
        except IOError, e:
            self.failed('IO Error ' + str(e))
            return None
            
    def _get_piece(self, index, begin, length):
        if not self.have[index]:
            return None
        if not self.waschecked[index]:
            
            if mhash_flag:
                if MHASH(MHASH_SHA1, (self.storage.read(self.piece_size * self.places[index], self._piecelen(index)))).digest() != self.hashes[index]:
                    self.failed(_('told file complete on start-up, but piece failed hash check'))
                    return None                
            else:
                if sha(self.storage.read(self.piece_size * self.places[index], self._piecelen(index))).digest() != self.hashes[index]:
                    self.failed(_('told file complete on start-up, but piece failed hash check'))
                    return None
                
            self.waschecked[index] = True            
        if begin + length > self._piecelen(index):
            return None
        return self.storage.read(self.piece_size * self.places[index] + begin, length)
    
    #needed as becoding converts True->1 and False->0
    def fixarray(self, data):
        for i in xrange(len(data)):
            if data[i] == 1:
                data[i] = True
            elif data[i] == 0:
                data[i] = False
        return data

    
    def _LoadFastResume(self):
        if self.resumefile != None and os.path.exists(self.resumefile):
#            print "Loading resume data"
            try:
                fp = open(self.resumefile, 'rb')
                filedata = fp.read()
                fp.close()
                resumedata = bdecode(filedata)
                return resumedata
            except IOError, msg:
                self.errorfunc(_("Hashes are being rechecked as an IOerror occurred when reading fast resume data -")+str(msg), -2)
                pass
        return None

    def SaveFastResume(self):
        if self.resumefile != None:
            if self.numchecked == self.numpieces:
#                print "Saving resume data"
                try:
                    filedata = None
                    datesizedata = self.storage.GetDateSizeData()
                    if len(self.have.array) == self.numpieces:
                        resumedata = []
                        resumedata.append(self.have.array) #piece data (array)
                        resumedata.append(datesizedata) #date/size info
                        resumedata.append(self.fastversion) #version
                        if mhash_flag:
                            checksum = MHASH(MHASH_SHA1, ( str(self.have.array) )).digest()
                        else:        
                            checksum = sha( str(self.have.array) ).digest()
                        resumedata.append(checksum) #array SHA-1 checksum
                        filedata = bencode(resumedata)
                    fp = open(self.resumefile, 'wb')
                    fp.write(filedata)
                    fp.close()
                except IOError, msg:
                    self.errorfunc(_("Fast resume data was not saved as an IOerror exception occurred when writing fast resume data -")+str(msg), -2)
                    pass
            elif self.numchecked != self.numpieces:
                self.errorfunc(_("Fast Resume data was not saved as initial hashing was interrupted or previous fast resume file was corrupt"), -2)            
                self._KillResumeData()

    def _KillResumeData (self):
        if os.path.exists(self.resumefile):
            os.remove(self.resumefile) #kill file if it exists
