import base64
import hmac
import io
import json
import os
import sqlite3
from contextlib import closing
from functools import wraps

import google.generativeai as genai
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

app = Flask(__name__)
# Enable CORS on every route using flask_cors defaults, so a frontend served
# from another origin can call this API.
CORS(app)

# Load settings from config.json (resolved against the current working directory).
#   GEMINI_API_KEY - required; startup fails without it.
#   ADMIN_PASSWORD - optional; falls back to a development-only default.
try:
    with open('config.json', encoding='utf-8') as f:
        config = json.load(f)
except FileNotFoundError as err:
    raise RuntimeError('config.json file not found. Please create it with your GEMINI_API_KEY.') from err
except json.JSONDecodeError as err:
    # Report malformed JSON as the same startup error type, keeping the parse location.
    raise RuntimeError(f'config.json is not valid JSON: {err}') from err

if not isinstance(config, dict):
    raise RuntimeError('config.json must contain a JSON object')

GEMINI_API_KEY = config.get('GEMINI_API_KEY')
ADMIN_PASSWORD = config.get('ADMIN_PASSWORD', 'admin123')  # Default password for development
if not GEMINI_API_KEY:
    raise RuntimeError('GEMINI_API_KEY not set in config.json')
if 'ADMIN_PASSWORD' not in config:
    # The fallback password is hard-coded in this file; make its use visible in the logs.
    app.logger.warning('ADMIN_PASSWORD not set in config.json; using the insecure development default')

genai.configure(api_key=GEMINI_API_KEY)

# Database setup
def get_db(path='orders.db'):
    """Open a new SQLite connection whose rows support access by column name.

    Args:
        path: database file to open. Defaults to ``orders.db``, resolved against
            the current working directory.

    Returns:
        A fresh ``sqlite3.Connection`` with ``row_factory`` set to ``sqlite3.Row``.
        Using it as a context manager (``with conn:``) only commits or rolls back;
        the caller remains responsible for closing it.
    """
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row
    return conn

# Create the orders table if it does not exist yet, and migrate databases that
# were created before the ``notes`` column was added. ``closing`` guarantees the
# connection is closed; the inner ``conn`` context commits or rolls back.
with closing(get_db()) as conn, conn:
    conn.execute('''
        CREATE TABLE IF NOT EXISTS orders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            order_name TEXT,
            order_number TEXT,
            order_time TEXT,
            delivery_carrier TEXT,
            image_data BLOB,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            notes TEXT
        )
    ''')

    # Add the column only when it is actually missing. Catching every
    # OperationalError from ALTER TABLE would also hide real failures such as a
    # locked or read-only database.
    existing_columns = {row['name'] for row in conn.execute('PRAGMA table_info(orders)')}
    if 'notes' not in existing_columns:
        conn.execute('ALTER TABLE orders ADD COLUMN notes TEXT')

def admin_required(f):
    """Decorate a Flask view so it only runs for requests carrying the admin token.

    The request must send ``Authorization: Bearer <token>`` where ``<token>``
    equals ADMIN_PASSWORD (the value admin_login returns as ``token``).
    Otherwise a 401 JSON error is returned and the view is not called.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Unauthorized'}), 401

        # Take everything after the scheme. split(' ')[1] truncated tokens
        # (i.e. passwords) that contain spaces, so they could never match.
        token = auth_header[len('Bearer '):]

        # An unset/empty ADMIN_PASSWORD must never authorize anyone (a bare
        # "Bearer " header would otherwise match ''). compare_digest avoids
        # leaking the password through response-timing differences.
        expected = ADMIN_PASSWORD
        if not expected or not hmac.compare_digest(
            token.encode('utf-8'), str(expected).encode('utf-8')
        ):
            return jsonify({'error': 'Invalid token'}), 401

        return f(*args, **kwargs)
    return decorated_function

@app.route('/api/admin/login', methods=['POST'])
def admin_login():
    """Exchange the admin password for the bearer token used by admin-only routes.

    Expects a JSON object ``{"password": "..."}``. On success returns
    ``{"token": ADMIN_PASSWORD}`` (the token is the password itself); any other
    input, including a missing or non-JSON body, yields a 401 error.
    """
    data = request.get_json(silent=True)
    password = data.get('password') if isinstance(data, dict) else None

    # Require a real string password and a configured ADMIN_PASSWORD: with a
    # null ADMIN_PASSWORD, the old ``password == ADMIN_PASSWORD`` check let a
    # body without "password" (None == None) succeed. compare_digest keeps the
    # comparison constant-time.
    if (
        isinstance(password, str)
        and ADMIN_PASSWORD
        and hmac.compare_digest(password.encode('utf-8'), str(ADMIN_PASSWORD).encode('utf-8'))
    ):
        return jsonify({'token': ADMIN_PASSWORD})
    return jsonify({'error': 'Invalid password'}), 401

def _strip_code_fences(text):
    """Return ``text`` trimmed of whitespace and of a surrounding Markdown code fence.

    Handles an opening fence with or without a language tag (e.g. "```json") and a
    closing fence that may or may not sit on its own line. Text without a leading
    fence is returned stripped but otherwise unchanged.
    """
    text = text.strip()
    if not text.startswith('```'):
        return text
    newline = text.find('\n')
    # Drop the whole opening-fence line (including any language tag); a
    # single-line reply such as "```{...}```" has no newline, so drop the backticks.
    text = text[newline + 1:] if newline != -1 else text.lstrip('`')
    text = text.rstrip()
    if text.endswith('```'):
        text = text[:-3]
    return text.strip()


@app.route('/api/extract', methods=['POST'])
def extract():
    """Extract order fields from a receipt image using Gemini.

    Expects a JSON object ``{"imageBase64": ...}`` holding either raw base64 image
    data or a data URL (``data:<mime>;base64,<data>``). Returns orderName,
    orderNumber, orderTime and deliveryCarrier, with '' for any field the model
    left empty/null. Responds 400 for a missing or malformed image and 500 when
    the model call or the parsing of its reply fails.
    """
    data = request.get_json(silent=True)
    image_base64 = data.get('imageBase64') if isinstance(data, dict) else None
    if not isinstance(image_base64, str) or not image_base64:
        return jsonify({'error': 'Missing imageBase64'}), 400

    # Raw base64 carries no format information; keep the historical JPEG default.
    mime_type = 'image/jpeg'
    if image_base64.startswith('data:'):
        # partition() instead of split(',')[1]: a data URL without a comma used to
        # raise IndexError outside the try block below (an unhandled HTML 500).
        header, sep, image_base64 = image_base64.partition(',')
        if not sep or not image_base64:
            return jsonify({'error': 'Malformed data URL in imageBase64'}), 400
        # Forward the declared image type (e.g. image/png) instead of always claiming JPEG.
        declared_type = header[len('data:'):].split(';', 1)[0]
        if declared_type.startswith('image/'):
            mime_type = declared_type

    model = genai.GenerativeModel('gemini-2.5-flash-preview-04-17')
    prompt = (
        "Analyze the provided image of a receipt and extract the following information:\n"
        "- Order Name (customer's name or a relevant identifier, e.g., 'Emely A', 'Takeout Guest')\n"
        "- Order Number (unique identifier for the order, e.g., 'DD d9b7ce91', '12345B', 'Order #55')\n"
        "- Order Time (time the order was placed or is due, e.g., '12:30 PM', '05:45PM')\n"
        "- Delivery Carrier (e.g., 'Uber Eats', 'DoorDash', 'GrubHub', etc.)\n\n"
        "Return ONLY JSON in the format: {\"orderName\": \"...\", \"orderNumber\": \"...\", \"orderTime\": \"...\", \"deliveryCarrier\": \"...\"}.\n"
        "If a field is unclear or not found, use null or an empty string for its value. Prioritize accuracy for order number and time.\n"
        "For Order Name, if multiple names or identifiers are present, pick the most prominent one related to the customer or order."
    )
    try:
        response = model.generate_content([
            {
                'inline_data': {
                    'mime_type': mime_type,
                    'data': image_base64,
                }
            },
            prompt
        ], generation_config={
            'response_mime_type': 'application/json',
            'temperature': 0.1
        })
        # The JSON response MIME type is requested above, but tolerate a fenced reply anyway.
        result = json.loads(_strip_code_fences(response.text))
        if not isinstance(result, dict):
            return jsonify({'error': 'Model did not return a JSON object'}), 500
        return jsonify({
            'orderName': result.get('orderName') or '',
            'orderNumber': result.get('orderNumber') or '',
            'orderTime': result.get('orderTime') or '',
            'deliveryCarrier': result.get('deliveryCarrier') or ''
        })
    except Exception as e:
        # Request boundary: anything raised by the Gemini SDK or JSON parsing
        # becomes a JSON 500, and is logged with its traceback.
        app.logger.exception('Receipt extraction failed')
        return jsonify({'error': str(e)}), 500

@app.route('/api/orders', methods=['POST'])
def save_order():
    """Store a receipt image together with its order fields.

    Expects a JSON object with ``imageBase64`` (a data URL or raw base64 string)
    plus ``orderName``, ``orderNumber``, ``orderTime`` and ``deliveryCarrier``.
    Returns ``{"success": true}``, or a 400 error when the body is not a JSON
    object or the image is missing or not valid base64.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Expected a JSON object'}), 400

    image_base64 = data.get('imageBase64')
    order_name = data.get('orderName')
    order_number = data.get('orderNumber')
    order_time = data.get('orderTime')
    delivery_carrier = data.get('deliveryCarrier')

    if not isinstance(image_base64, str) or not image_base64:
        return jsonify({'error': 'Missing imageBase64'}), 400

    # Drop a "data:<mime>;base64," prefix (everything up to the first comma).
    # Raw base64 contains no comma and is used as-is; the old split(',')[1]
    # raised IndexError (HTTP 500) for it.
    prefix, comma, payload = image_base64.partition(',')
    if not comma:
        payload = prefix
    if not payload:
        return jsonify({'error': 'Missing imageBase64'}), 400

    try:
        image_data = base64.b64decode(payload)
    except ValueError:  # binascii.Error (bad padding) is a ValueError subclass
        return jsonify({'error': 'imageBase64 is not valid base64'}), 400

    # closing() releases the connection; the inner ``conn`` context commits the insert.
    with closing(get_db()) as conn, conn:
        conn.execute(
            'INSERT INTO orders (order_name, order_number, order_time, delivery_carrier, image_data) VALUES (?, ?, ?, ?, ?)',
            (order_name, order_number, order_time, delivery_carrier, image_data)
        )

    return jsonify({'success': True})

@app.route('/api/orders', methods=['GET'])
@admin_required
def list_orders():
    """Return one page of orders (without image data) plus pagination metadata.

    Query parameters:
        page: 1-based page number (default 1; values below 1 are treated as 1).
        per_page: page size (default 50, clamped to the range 1..100).
        search: substring matched (LIKE) against order_name, order_number and
            delivery_carrier.
        carrier: exact delivery_carrier match.
        date_from / date_to: inclusive bounds compared against date(timestamp).
        sort_by: timestamp, order_name, order_number or delivery_carrier
            (anything else falls back to timestamp).
        sort_order: asc or desc, case-insensitive (anything else falls back to desc).
    """
    page = max(request.args.get('page', 1, type=int), 1)
    # Clamp per_page to 1..100: 0 would divide by zero when computing
    # total_pages, and a negative value becomes a negative LIMIT, which SQLite
    # treats as "no limit" -- bypassing the 100-row cap.
    per_page = min(max(request.args.get('per_page', 50, type=int), 1), 100)
    search = request.args.get('search', '', type=str)
    carrier = request.args.get('carrier', '', type=str)
    date_from = request.args.get('date_from', '', type=str)
    date_to = request.args.get('date_to', '', type=str)
    sort_by = request.args.get('sort_by', 'timestamp', type=str)
    sort_order = request.args.get('sort_order', 'desc', type=str)
    offset = (page - 1) * per_page

    # Build the WHERE clause from bound parameters only.
    where_conditions = []
    params = []

    if search:
        where_conditions.append("(order_name LIKE ? OR order_number LIKE ? OR delivery_carrier LIKE ?)")
        search_param = f"%{search}%"
        params.extend([search_param, search_param, search_param])

    if carrier:
        where_conditions.append("delivery_carrier = ?")
        params.append(carrier)

    if date_from:
        where_conditions.append("date(timestamp) >= ?")
        params.append(date_from)

    if date_to:
        where_conditions.append("date(timestamp) <= ?")
        params.append(date_to)

    where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""

    # ORDER BY terms cannot be bound as parameters, so they are interpolated
    # into the SQL below and must come from a fixed whitelist.
    valid_sort_columns = ('timestamp', 'order_name', 'order_number', 'delivery_carrier')
    if sort_by not in valid_sort_columns:
        sort_by = 'timestamp'
    sort_direction = 'ASC' if sort_order.lower() == 'asc' else 'DESC'

    with closing(get_db()) as conn:
        count_query = f"SELECT COUNT(*) as total FROM orders {where_clause}"
        total_count = conn.execute(count_query, params).fetchone()['total']

        orders_query = f"""
            SELECT id, order_name, order_number, order_time, delivery_carrier, timestamp, notes
            FROM orders {where_clause}
            ORDER BY {sort_by} {sort_direction}
            LIMIT ? OFFSET ?
        """
        orders = conn.execute(orders_query, params + [per_page, offset]).fetchall()

    total_pages = (total_count + per_page - 1) // per_page  # ceiling division

    return jsonify({
        'orders': [dict(order) for order in orders],
        'pagination': {
            'page': page,
            'per_page': per_page,
            'total': total_count,
            'total_pages': total_pages,
            'has_next': page < total_pages,
            'has_prev': page > 1
        }
    })

@app.route('/api/orders/<int:order_id>/image')
@admin_required
def serve_image(order_id):
    """Return the stored receipt image for ``order_id`` with mimetype image/jpeg.

    Responds with a 404 JSON error when the order does not exist or has no
    image data stored (previously a NULL image was served as an empty 200).
    """
    with closing(get_db()) as conn:
        order = conn.execute('SELECT image_data FROM orders WHERE id = ?', (order_id,)).fetchone()
    if order is None:
        return jsonify({'error': 'Order not found'}), 404
    if order['image_data'] is None:
        return jsonify({'error': 'Image not found'}), 404
    # NOTE(review): always labelled JPEG, although the stored bytes are whatever
    # image was uploaded -- confirm clients only upload JPEGs.
    return send_file(
        io.BytesIO(order['image_data']),
        mimetype='image/jpeg'
    )

@app.route('/api/orders/<int:order_id>', methods=['PUT'])
@admin_required
def update_order(order_id):
    """Replace the editable fields of an order.

    Expects a JSON object with order_name, order_number, order_time,
    delivery_carrier and notes; any of these keys that is absent is stored as
    NULL. Returns ``{"success": true}``, 400 when the body is not a JSON
    object, or 404 when no order has this id.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Expected a JSON object'}), 400

    order_name = data.get('order_name')
    order_number = data.get('order_number')
    order_time = data.get('order_time')
    delivery_carrier = data.get('delivery_carrier')
    notes = data.get('notes')

    with closing(get_db()) as conn, conn:
        # Use the affected-row count instead of a prior SELECT: one statement,
        # and no window for the row to vanish between the check and the update.
        cursor = conn.execute('''
            UPDATE orders
            SET order_name = ?, order_number = ?, order_time = ?, delivery_carrier = ?, notes = ?
            WHERE id = ?
        ''', (order_name, order_number, order_time, delivery_carrier, notes, order_id))
        if cursor.rowcount == 0:
            return jsonify({'error': 'Order not found'}), 404

    return jsonify({'success': True})

@app.route('/api/orders/<int:order_id>', methods=['DELETE'])
@admin_required
def delete_order(order_id):
    """Delete the order with ``order_id``.

    Returns ``{"success": true}``, or a 404 JSON error when no order has this id.
    """
    with closing(get_db()) as conn, conn:
        # rowcount tells us whether a row existed, so no separate SELECT is needed.
        cursor = conn.execute('DELETE FROM orders WHERE id = ?', (order_id,))
        if cursor.rowcount == 0:
            return jsonify({'error': 'Order not found'}), 404

    return jsonify({'success': True})

if __name__ == '__main__':
    # Start Flask's development server. FLASK_RUN_HOST / FLASK_RUN_PORT (the
    # variables `flask run` also reads) override the defaults of all interfaces
    # on port 5000 without editing the code.
    app.run(
        host=os.environ.get('FLASK_RUN_HOST', '0.0.0.0'),
        port=int(os.environ.get('FLASK_RUN_PORT', '5000')),
    )