#!/usr/bin/python3
import sqlite3
import os.path
import sys
import subprocess
import zipfile
import xml.etree.ElementTree
import re
import chardet
import config
from multiprocessing import Pool



class pagedata:
	"""Text extracted from one page of a document.

	Instances are created empty and the processors assign both fields
	afterwards; until then they read the class-level defaults below.
	"""
	# 1-based page number; 0 marks an entry that is not a single page
	# (a whole-document entry, or a format without pages)
	page: int = 0
	# the extracted text of that page
	content: str = ""
	
	
def singlepagelist(content):
	"""Wrap *content* in a one-element list holding a page-0 pagedata.

	Used by the processors for formats that have no page structure.
	"""
	entry = pagedata()
	entry.page = 0
	entry.content = content
	return [entry]

def striptags(content):
	"""Return *content* with its XML/HTML markup removed, keeping only the text.

	Well-formed XML is parsed and its text nodes are concatenated in document
	order. Input the XML parser rejects (typically HTML) falls back to deleting
	every ``<...>`` run with a regular expression.

	Neither path inserts a separator where a tag was removed, so words on
	either side of a tag can merge.
	"""
	try:
		return ''.join(xml.etree.ElementTree.fromstring(content).itertext())
	except (xml.etree.ElementTree.ParseError, ValueError):
		# ParseError: not well-formed XML (the usual case for HTML).
		# ValueError: covers UnicodeEncodeError on text the parser cannot encode.
		#TODO: test<br>test2 will make it testtest2 not test test2
		return re.sub(r'<[^>]*>', '', content)
	

def strip_irrelevant(content):
	"""Flatten layout whitespace in *content*.

	Newlines and tabs become spaces, then every run of two or more spaces is
	collapsed to one. All other characters, including form feeds, are kept.
	"""
	flattened = content.translate({ord("\n"): " ", ord("\t"): " "})
	return re.sub(r" {2,}", " ", flattened)



def process_pdf(path):
	"""Extract the text of a PDF, one pagedata entry per page.

	Runs ``pdftotext`` once over the whole document and splits its output on
	form feeds (``\\f``), one segment per page. Pages are numbered from 1. An
	extra entry with page 0 holds the text of the whole document.

	Args:
		path: path of the PDF file.

	Returns:
		list of pagedata: pages 1..n, followed by the page-0 whole-document
		entry.

	Raises:
		FileNotFoundError: if the ``pdftotext`` executable cannot be found.
	"""
	args = ["pdftotext", path, "-"]
	# pdftotext's stderr was never used; discard it instead of buffering it
	completed = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
	# errors="replace": one undecodable byte must not abort the whole document
	content = strip_irrelevant(completed.stdout.decode("utf-8", errors="replace"))
	#it is faster to do it like this than to call pdftotext for each page
	pages = content.split("\f")
	# pdftotext ends every page, the last one included, with "\f", which
	# leaves a trailing empty segment that is not a real page
	if pages and not pages[-1].strip():
		pages.pop()

	result = []
	for number, text in enumerate(pages, start=1):
		data = pagedata()
		data.page = number
		data.content = text
		result.append(data)

	#TODO: current hack, so we can fts search several words over the whole document
	#this of course uses more space, but in the end that's not a big problem
	#Nevertheless, this remains a hack
	everything = pagedata()
	everything.page = 0
	everything.content = content.replace("\f", "")
	result.append(everything)
	return result
	
def process_odt(path):
	"""Extract the text of an OpenDocument file as a single page-0 entry.

	Reads ``content.xml`` from the document's ZIP container, decodes it as
	UTF-8 and strips its markup with striptags().

	Raises:
		zipfile.BadZipFile: if *path* is not a ZIP archive.
		KeyError: if the archive has no ``content.xml``.
		UnicodeDecodeError: if ``content.xml`` is not valid UTF-8.
	"""
	# the with-block closes the archive even when reading or decoding fails
	with zipfile.ZipFile(path) as archive:
		content = archive.read("content.xml").decode("utf-8")
	return singlepagelist(striptags(content))
	
def readtext(path):
	"""Read the file at *path* and return its contents as text.

	The bytes are decoded as UTF-8 first. If that fails, chardet guesses the
	encoding. When no guess is available, "" is returned. When the guessed
	codec is unknown or cannot decode the bytes, a message is printed and ""
	is returned.
	"""
	with open(path, "rb") as fd:
		content = fd.read()

	try:
		return content.decode("utf-8")
	except UnicodeDecodeError:
		pass

	encoding = chardet.detect(content)["encoding"]
	if encoding is None:
		return ""
	try:
		return content.decode(encoding)
	except (UnicodeDecodeError, LookupError):
		# LookupError: chardet named a codec Python does not know
		print("FAILED DECODING: " + path)
		return ""
	
def process_striptags(path):
	"""Read a markup file (e.g. HTML) and return its tag-free text as page 0."""
	return singlepagelist(striptags(readtext(path)))
	
def process_text(path):
	"""Return a plain-text file's decoded contents, unchanged, as page 0."""
	text = readtext(path)
	return singlepagelist(text)
	
def process_nothing(path):
	"""Processor for files whose content is not indexed: return no pages.

	*path* is ignored. insert() still records the file row itself.
	"""
	return []
	
def exists(abspath, mtime):
	"""Return True if the ``file`` table already holds *abspath* with this *mtime*.

	Queries through the module-level ``dbcon`` connection opened by init().
	"""
	row = dbcon.cursor().execute(
		"SELECT 1 FROM file WHERE path = ? AND mtime = ?" , (abspath, mtime)
	).fetchone()
	return row is not None and row[0] == 1

def insert(path):
	"""Index one file, replacing any previous entry for it.

	The file is skipped when the database already has a row for its absolute
	path with the same (integer-truncated) mtime. Otherwise its pages are
	extracted by the processor registered in ``preprocess`` for its extension,
	matched case-insensitively, or by process_nothing when there is none. The
	``file`` row and its ``content`` rows are then rewritten in one transaction.

	Args:
		path: path of the file, absolute or relative to the working directory.

	Raises:
		OSError: if the file cannot be stat'ed.
		sqlite3.Error: if a statement fails. The transaction is rolled back
			first, so the worker's connection stays usable for later files.
	"""
	print("processing", path)
	abspath = os.path.abspath(path)
	mtime = int(os.stat(abspath).st_mtime)

	if exists(abspath, mtime):
		print("Leaving alone " + abspath + " because it wasn't changed")
		return

	ext = os.path.splitext(abspath)[1].lower()
	processor = preprocess.get(ext, process_nothing)
	# Extract the text before BEGIN so the write lock is held only briefly
	pages = processor(abspath)

	# Each pool worker opens its own connection in init(), so concurrent
	# writers are serialized by SQLite's file locking.
	cursor = dbcon.cursor()
	cursor.execute("BEGIN TRANSACTION")
	try:
		cursor.execute("DELETE FROM file WHERE path = ?", (abspath,))
		cursor.execute("INSERT INTO file(path, mtime) VALUES(?, ?) ", (abspath, mtime))
		fileid = cursor.lastrowid
		cursor.executemany(
			"INSERT INTO content(fileid, page, content) VALUES(?, ?, ?)",
			[(fileid, page.page, page.content) for page in pages],
		)
		cursor.execute("COMMIT TRANSACTION")
	except BaseException:
		# Without a rollback, the open transaction would keep the write lock
		# and make this connection's next BEGIN fail.
		if dbcon.in_transaction:
			cursor.execute("ROLLBACK TRANSACTION")
		raise

# Maps a file extension (including the leading dot) to the function that
# extracts its pages. Extensions missing here are handled by process_nothing.
preprocess = {
	".pdf": process_pdf,
	".odt": process_odt,
	".html": process_striptags,
	".xml": process_nothing,
	".txt": process_text,
	".sql": process_text,
	".c": process_text,
	".cpp": process_text,
	".js": process_text,
	".java": process_text,
	".py": process_text,
	".md": process_text,
}



def yieldstdinfiles():
	"""Yield the file paths read from stdin, one per line.

	The trailing newline is removed. Blank lines are skipped, because an empty
	path would make insert() index the current working directory.
	"""
	for line in sys.stdin:
		path = line.rstrip("\n")
		if path:
			yield path

	
def init():
	"""Pool initializer: give this worker process its own SQLite connection.

	The connection is stored in the module-level ``dbcon`` used by exists()
	and insert(). isolation_level=None disables the sqlite3 module's implicit
	transactions, so insert() issues BEGIN/COMMIT itself.
	"""
	global dbcon
	connection = sqlite3.connect(config.DBPATH, isolation_level=None)
	dbcon = connection


# Per-process SQLite connection; each pool worker sets it in init().
dbcon = None

if __name__ == '__main__':
	# Paths come from the command line or, when none are given, from stdin
	# (one per line).
	paths = sys.argv[1:] if len(sys.argv) >= 2 else yieldstdinfiles()
	with Pool(processes=4, initializer=init) as pool:
		pool.map(insert, paths)