#!/usr/bin/env python3
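"""Download an entire forum thread from the DarkOwl Vision Search API.

Given the document ID of any post in a thread (it does not have to be the
original post), the script resolves the thread's metadata, pages through
every post, and saves the results as both JSON and CSV in the current
directory.
"""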

import requests
import hmac
import hashlib
import base64
import argparse
import sys
import csv
import json
from datetime import datetime, timezone

# DarkOwl Vision API credentials (fill these in before running).
publicKey = 'PUBLIC_KEY_HERE'
privateKey = 'PRIVATE_KEY_HERE'

def main(args):
    # Look up the document by ID to get the thread metadata.
    payload = {'q': args.docId, 'sort': 'p', 'count': 1}

    response = perform_query(payload)
    if response['total'] == 0:
        print('No post found. Check that your document ID is correct and try again.')
        sys.exit(1)

    result = response['results'][0]
    if 'forum' not in result:
        print('The returned document is not a forum post; cannot determine a thread ID.')
        sys.exit(1)

    thread_title = result['title']
    thread_id = result['forum']['threadId']
    domain = result['domain']
    site = result.get('siteId')

    print(f'Thread title: {thread_title}')
    print(f'Thread ID: {thread_id}')
    print(f'Site ID: {site}')
    print(f'Site Domain: {domain}')
    print('-'*50)

    # Prepare the thread ID query used to retrieve all of the thread's posts.
    payload = {'domain': domain, 'threadId': thread_id, 'sort': '-p'}

    # dict to hold all of the thread metadata and posts
    thread_info = {
        'Thread Title': thread_title,
        'Site ID': site,
        'Thread URL': None,
        'Thread Posts': []
    }
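    # 'Thread URL' is filled in from the first post's location once page one is fetched.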

    num_pages = None
    page = 0
    while True:
        # Paginate to get up to 5000 posts before updating the date filter.
        for offset in range(0, 5000, 20):
            payload['offset'] = offset
            page += 1
            print(f'Getting page {page} of ', end='')
            response = perform_query(payload)

            # First page: derive the total page count and the OP's URL.
            if num_pages is None:
                num_pages = (response['total'] + 19) // 20  # ceiling division by page size
                thread_info['Thread URL'] = response['results'][0]['location']

            print(num_pages)

            # Add this page's posts to the thread info list.
            for post in response['results']:
                thread_info['Thread Posts'].append({
                    'postAuthor': post['forum']['postAuthor'],
                    'postDate': post['forum']['postDate'],
                    'body': post.get('body')
                })

            # break if there are fewer than 20 posts in the response, meaning this is the last page
            if response['resultCount'] < 20:
                break

            # At the pagination limit, filter the next batch of 5000 to start from the last postDate.
            if offset == 4980:
                payload['postedFrom'] = response['results'][-1]['forum']['postDate']
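                # NOTE (assumption): if the API treats postedFrom as inclusive, the first
                # post of the next batch may duplicate the last one saved here; dedupe on
                # postDate/postAuthor downstream if that proves to be the case.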

        # The break above only exits the for loop; check again to leave the while loop.
        if response['resultCount'] < 20:
            break
  
    # Create the output filename; fall back to the domain when there is no site ID.
    filename = f'{site or domain} {thread_title}'.replace('/', ' ')[:200]

    # Save JSON output.
    with open(f'./{filename}.json', 'w', encoding='utf-8') as outFile:
        json.dump(thread_info, outFile)

    # Save CSV output. newline='' prevents blank rows on Windows.
    with open(f'./{filename}.csv', 'w', newline='', encoding='utf-8') as outFile:
        writer = csv.DictWriter(outFile, fieldnames=thread_info['Thread Posts'][0].keys())
        writer.writeheader()
        writer.writerows(thread_info['Thread Posts'])

    print('\nThread Download Complete')        


def perform_query(payload):
    host = 'api.darkowl.com'
    endpoint = '/api/v1/search'

    # Build the query string; the same path + query string is signed in the headers.
    search = payloadToString(payload)
    url = f'https://{host}{endpoint}{search}'
    absPath = endpoint + search

    headers = generate_headers(absPath)
    r = requests.get(url, headers=headers)
    if r.status_code == 200:
        return r.json()
    else:
        print(r.content)
        sys.exit(1)

def generate_headers(abs_path, http_method='GET'):
    # The date string is part of the signed payload and must match X-VISION-DATE.
    date = datetime.now(timezone.utc).strftime('%a, %d %b %Y %H:%M:%S GMT')

    # Sign METHOD + path + date with the private key: HMAC-SHA1, base64-encoded.
    string2hash = http_method + abs_path + date
    hmacsha1 = hmac.new(privateKey.encode('utf-8'), string2hash.encode('utf-8'), hashlib.sha1).digest()
    base64encoded = base64.b64encode(hmacsha1).decode('utf-8')
    auth_header = f'OWL {publicKey}:{base64encoded}'

    headers = {'Authorization': auth_header, 'X-VISION-DATE': date, 'Accept': 'application/json'}

    return headers


### Takes a payload and generates a URL query string
def payloadToString(payload):
    # Values are concatenated as-is (no URL encoding) so the string sent in the
    # request matches the string that gets signed. If doc IDs or dates ever carry
    # reserved characters, both sides would need urllib.parse.quote applied.
    if not payload:
        return ''
    return '?' + '&'.join(f'{key}={value}' for key, value in payload.items())

def parse_command_line():
    description = f'''
    Tool to download an entire forum thread from the DarkOwl Vision Search API. Provide one argument: the document ID of
    any post within the thread. It does not have to be the original post; any post in the thread will download the entire thread.

    Usage:

    {sys.argv[0]} document_id
    '''

    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, description=description)
    parser.add_argument('docId', help="Doc ID of any post in the thread.", type=str)
    args = parser.parse_args()

    return args
  

if __name__ == '__main__':
    args = parse_command_line()
    main(args)
