from fastapi import APIRouter, Depends, HTTPException, Response
from typing import List, Dict
import subprocess
import re
import logging
import hashlib
import shlex

from auth import get_current_active_user

router = APIRouter()

def get_base_and_log_dir():
    """Return the ``(BASE_DIR, LOG_DIR)`` pair defined by the ``app`` module.

    The import happens at call time rather than at module load, presumably
    to avoid a circular import between ``app`` and this router (TODO confirm).
    """
    from app import BASE_DIR as base_dir, LOG_DIR as log_dir
    return base_dir, log_dir

def get_firewall_paths():
    """Return the path of ``firewall.log`` inside LOG_DIR.

    The log directory and an empty log file are created if the file does
    not exist yet.
    """
    _, log_dir = get_base_and_log_dir()
    firewall_log = log_dir / "firewall.log"
    if firewall_log.exists():
        return firewall_log
    firewall_log.parent.mkdir(parents=True, exist_ok=True)
    firewall_log.touch(exist_ok=True)
    return firewall_log

# Service rules that must always be present. ensure_required_rules() adds
# them when they are missing, and get_firewall_rules_from_iptables() hides
# them from API clients via is_required_rule(). Each dict uses the same keys
# as user rules (see rule_to_cmd); an empty string means "option not set".
REQUIRED_RULES = [
    # Accept everything arriving on the loopback interface.
    {
        "chain": "INPUT",
        "protocol": "",
        "src": "",
        "dst": "",
        "dport": "",
        "sport": "",
        "in_interface": "lo",
        "out_interface": "",
        "target": "ACCEPT",
        "match": ""
    },
    # Accept everything leaving through the loopback interface.
    {
        "chain": "OUTPUT",
        "protocol": "",
        "src": "",
        "dst": "",
        "dport": "",
        "sport": "",
        "in_interface": "",
        "out_interface": "lo",
        "target": "ACCEPT",
        "match": ""
    },
    # Disabled: accept packets of ESTABLISHED,RELATED connections.
    # NOTE(review): iptables -S may list the states as "RELATED,ESTABLISHED",
    # which would not compare equal to the "match" string below, so the rule
    # might be re-added on every check -- verify before re-enabling.
#    {
#        "chain": "INPUT",
#        "protocol": "",
#        "src": "",
#        "dst": "",
#        "dport": "",
#        "sport": "",
#        "in_interface": "",
#        "out_interface": "",
#        "target": "ACCEPT",
#        "match": "-m conntrack --ctstate ESTABLISHED,RELATED"
#    },
#    {
#        "chain": "OUTPUT",
#        "protocol": "",
#        "src": "",
#        "dst": "",
#        "dport": "",
#        "sport": "",
#        "in_interface": "",
#        "out_interface": "",
#        "target": "ACCEPT",
#        "match": "-m conntrack --ctstate ESTABLISHED,RELATED"
#    }
]

def rule_to_cmd(rule):
    """Build the ``sudo iptables -A`` argument list that appends *rule*.

    Options are emitted in a fixed order:
    1. protocol, source, destination;
    2. destination/source port, input/output interface;
    3. the whitespace-split ``match`` tokens;
    4. the comment;
    5. ``-j target``.

    Only truthy fields produce options. Ports and the comment are converted
    with ``str()``; the other values are passed through unchanged.
    """
    argv = ["sudo", "iptables", "-A", rule["chain"]]
    option_flags = (
        ("protocol", "-p"),
        ("src", "-s"),
        ("dst", "-d"),
        ("dport", "--dport"),
        ("sport", "--sport"),
        ("in_interface", "-i"),
        ("out_interface", "-o"),
    )
    for field, flag in option_flags:
        value = rule.get(field)
        if not value:
            continue
        argv.extend((flag, str(value) if field in ("dport", "sport") else value))
    match_text = rule.get("match")
    if match_text:
        argv.extend(match_text.split())
    comment = rule.get("comment")
    if comment:
        argv.extend(("-m", "comment", "--comment", str(comment)))
    argv.extend(("-j", rule["target"]))
    return argv

def required_rule_exists(rule):
    result = subprocess.run(["sudo", "iptables", "-S", rule["chain"]], capture_output=True, text=True)
    lines = result.stdout.strip().splitlines()
    count = 0
    
    for line in lines:
        if not line.startswith("-A"):
            continue
        parts = line.split()
        if rule["chain"] != parts[1]:
            continue
        if rule["in_interface"] and "-i" in parts:
            if parts[parts.index("-i")+1] != rule["in_interface"]:
                continue
        elif rule["in_interface"] and "-i" not in parts:
            continue
        if rule["out_interface"] and "-o" in parts:
            if parts[parts.index("-o")+1] != rule["out_interface"]:
                continue
        elif rule["out_interface"] and "-o" not in parts:
            continue
        if rule["match"]:
            match_parts = rule["match"].split()
            if not all(m in parts for m in match_parts):
                continue
        if rule["target"] and "-j" in parts:
            if parts[parts.index("-j")+1] != rule["target"]:
                continue
        elif rule["target"] and "-j" not in parts:
            continue
        count += 1
        
    return count

def ensure_required_rules():
    for rule in REQUIRED_RULES:
        count = required_rule_exists(rule)
        if count == 0:
            cmd = rule_to_cmd(rule)
            subprocess.run(cmd)
        elif count > 1:
            delete_specific_rule(rule)
            cmd = rule_to_cmd(rule)
            subprocess.run(cmd)
            
def delete_specific_rule(rule):
    result = subprocess.run(["sudo", "iptables", "-S", rule["chain"]], capture_output=True, text=True)
    lines = result.stdout.strip().splitlines()

    def normalize_comment(val):
        if val is None:
            return ""
        return str(val).strip('"').strip()

    rule_comment = normalize_comment(rule.get("comment"))

    for line in lines:
        if not line.startswith("-A"):
            continue

        parts = shlex.split(line)
        if rule["chain"] != parts[1]:
            continue

        match = True
        for key, flag in [
            ("protocol", "-p"),
            ("src", "-s"),
            ("dst", "-d"),
            ("dport", "--dport"),
            ("sport", "--sport"),
            ("in_interface", "-i"),
            ("out_interface", "-o"),
        ]:
            if rule.get(key):
                if flag in parts:
                    if parts[parts.index(flag) + 1] != str(rule[key]):
                        match = False
                        break
                else:
                    match = False
                    break

        if match and rule.get("match"):
            match_parts = rule["match"].split()
            if not all(m in parts for m in match_parts):
                match = False

        if match and rule.get("comment"):
            if "--comment" in parts:
                comment_val = normalize_comment(parts[parts.index("--comment") + 1])
                if comment_val != rule_comment:
                    match = False
            else:
                match = False

        if match and rule.get("target"):
            if "-j" in parts:
                if parts[parts.index("-j") + 1] != rule["target"]:
                    match = False
            else:
                match = False

        if match:
            delete_cmd = ["sudo", "iptables"] + ["-D"] + parts[1:]
            subprocess.run(delete_cmd)

def apply_iptables_rules(rules: List[Dict], clear_first=True):
    """Load *rules* into iptables.

    With ``clear_first`` the current policies of INPUT/OUTPUT/FORWARD are
    recorded, all rules are flushed, and the recorded policies are set
    again. Without it, only duplicated rules are removed.

    Then the required service rules are ensured, and each rule in *rules* is
    appended once. Rules whose ``str()`` form (ignoring ``id``) was already
    applied in this call are skipped.
    """
    if not clear_first:
        cleanup_duplicate_rules()
    else:
        saved_policies = {}
        for chain_name in ("INPUT", "OUTPUT", "FORWARD"):
            listing = subprocess.run(["sudo", "iptables", "-nL", chain_name], capture_output=True, text=True)
            found = re.search(rf"Chain {chain_name} \(policy (\w+)\)", listing.stdout)
            saved_policies[chain_name] = found.group(1) if found else "ACCEPT"
        subprocess.run(["sudo", "iptables", "-F"])
        for chain_name, policy in saved_policies.items():
            subprocess.run(["sudo", "iptables", "-P", chain_name, policy])
    ensure_required_rules()
    seen_digests = set()
    for candidate in rules:
        payload = str({key: val for key, val in candidate.items() if key != "id"})
        digest = hashlib.sha256(payload.encode()).hexdigest()
        if digest not in seen_digests:
            seen_digests.add(digest)
            subprocess.run(rule_to_cmd(candidate))

def cleanup_duplicate_rules():
    result = subprocess.run(["sudo", "iptables", "-S"], capture_output=True, text=True)
    lines = result.stdout.strip().splitlines()
    unique_rules = {}
    duplicates = []
    for line in lines:
        if not line.startswith("-A"):
            continue
        if line in unique_rules:
            duplicates.append(line)
        else:
            unique_rules[line] = True
    for dup in duplicates:
        parts = dup.split()
        delete_cmd = ["sudo", "iptables"] + ["-D"] + parts[1:]
        subprocess.run(delete_cmd)
    return len(duplicates)

def initialize_firewall():
    """Tidy the firewall.

    Removes duplicated rules, logging how many were dropped, and then
    ensures the service rules from REQUIRED_RULES are present.
    """
    duplicates_removed = cleanup_duplicate_rules()
    if duplicates_removed > 0:
        logging.info(f"Удалено {duplicates_removed} дублирующихся правил файрвола")
    ensure_required_rules()

# Executed at import time, so importing this module runs "sudo iptables" commands.
initialize_firewall()

def _validate_port(rule: Dict, field: str) -> None:
    """Raise HTTP 400 unless ``rule[field]`` is unset/empty or a port 1-65535.

    The value's ``str()`` must be a plain decimal number without leading
    zeros, because rule_to_cmd passes ``str(value)`` to iptables.
    ``int()`` alone would also accept ``True``, the float ``22.0``, ``" 22"``
    or non-ASCII digits, whose ``str()`` is not a usable port. Leading zeros
    are rejected to avoid any base ambiguity (possible octal parsing).
    """
    value = rule.get(field)
    if not value:
        return
    text = str(value)
    if not re.fullmatch(r"[1-9][0-9]{0,4}", text) or int(text) > 65535:
        raise HTTPException(status_code=400, detail=f"Invalid {field} value")


def validate_rule(rule: Dict):
    """Validate a user-submitted rule dict.

    Raises:
        HTTPException: 400 if ``chain``/``target`` is missing, the chain is
            not INPUT/OUTPUT/FORWARD, a field contains forbidden characters,
            a port is invalid, a port is given without a protocol, or the
            comment is not a string of at most 128 characters.
    """
    if "chain" not in rule or "target" not in rule:
        raise HTTPException(status_code=400, detail="Missing required fields (chain or target)")
    if rule["chain"] not in ["INPUT", "OUTPUT", "FORWARD"]:
        raise HTTPException(status_code=400, detail="Invalid chain value")
    # Commands are executed as argument lists without a shell, so this is
    # input hygiene rather than the only barrier against injection.
    dangerous_chars = [';', '&&', '||', '`', '$', '|', '>', '<']
    fields_to_check = ["protocol", "src", "dst", "in_interface", "out_interface", "match"]
    for field in fields_to_check:
        if field in rule and any(char in str(rule[field]) for char in dangerous_chars):
            raise HTTPException(status_code=400, detail=f"Invalid characters in {field}")
    _validate_port(rule, "dport")
    _validate_port(rule, "sport")
    if (rule.get("dport") or rule.get("sport")) and not rule.get("protocol"):
        # rule_to_cmd emits --dport/--sport, which iptables only accepts
        # after a port-aware protocol (e.g. "-p tcp") has been given.
        raise HTTPException(status_code=400, detail="dport/sport require a protocol (e.g. tcp or udp)")
    comment = rule.get("comment")
    if comment:
        if not isinstance(comment, str):
            raise HTTPException(status_code=400, detail="Comment must be a string")
        if len(comment) > 128:
            raise HTTPException(status_code=400, detail="Comment too long (max 128 chars)")
        comment_dangerous_chars = [';', '&&', '||', '`', '$', '|', '>', '<', '"']
        if any(char in comment for char in comment_dangerous_chars):
            raise HTTPException(status_code=400, detail="Invalid characters in comment")

def is_required_rule(rule: Dict) -> bool:
    """Return True if *rule* is exactly one of the service rules in REQUIRED_RULES.

    Two conditions must hold:
    - Every key of the required rule matches; keys missing from *rule*
      count as "".
    - *rule* carries no other non-empty field apart from its ``id``.

    Without the second condition, a user's rule such as a loopback ACCEPT
    with a comment would be mistaken for the service rule. It would then be
    hidden from the API.
    """
    for required in REQUIRED_RULES:
        same_fields = all(rule.get(key, "") == value for key, value in required.items())
        has_extra_fields = any(
            value for key, value in rule.items() if key != "id" and key not in required
        )
        if same_fields and not has_extra_fields:
            return True
    return False

def get_firewall_rules_from_iptables():
    """Return the user-visible rules currently loaded in iptables.

    Every ``-A`` line of ``sudo iptables -S`` becomes a rule dict:
    - value options fill protocol/src/dst/dport/sport/in_interface/
      out_interface/target;
    - ``-m comment --comment X`` fills ``comment``;
    - every other token is appended to ``match``, each followed by a space.

    Each dict gets an ``id``: the first 12 hex digits of the SHA-256 of its
    ``str()`` form, so ids depend on rule content and key order. Rules for
    which is_required_rule() is true are left out.
    """
    value_options = {
        "-p": "protocol",
        "-s": "src",
        "-d": "dst",
        "--dport": "dport",
        "--sport": "sport",
        "-i": "in_interface",
        "-o": "out_interface",
        "-j": "target",
    }
    listing = subprocess.run(["sudo", "iptables", "-S"], capture_output=True, text=True)
    parsed_rules = []
    for raw_line in listing.stdout.strip().splitlines():
        if not raw_line.startswith("-A"):
            continue
        tokens = shlex.split(raw_line)
        entry = {"chain": tokens[1]}
        pos = 2
        while pos < len(tokens):
            token = tokens[pos]
            field = value_options.get(token)
            if field is not None:
                entry[field] = tokens[pos + 1]
                pos += 2
            elif token == "-m" and pos + 1 < len(tokens) and tokens[pos + 1] == "comment":
                pos += 2
                if pos < len(tokens) and tokens[pos] == "--comment":
                    entry["comment"] = tokens[pos + 1]
                    pos += 2
            else:
                entry["match"] = entry.get("match", "") + token + " "
                pos += 1
        fingerprint = str({key: val for key, val in entry.items() if key != "id"})
        entry["id"] = hashlib.sha256(fingerprint.encode()).hexdigest()[:12]
        parsed_rules.append(entry)
    # Exclude the service rules so they are never exposed to clients.
    return [entry for entry in parsed_rules if not is_required_rule(entry)]

@router.get("/rules", response_model=List[Dict])
async def get_firewall_rules(current_user: Dict = Depends(get_current_active_user)):
    """List the user-visible firewall rules.

    Duplicates are removed and the service rules are ensured before the
    listing is read.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        removed_count = cleanup_duplicate_rules()
        if removed_count > 0:
            logging.info(f"Удалено {removed_count} дублирующихся правил при запросе")
        ensure_required_rules()
        rules = get_firewall_rules_from_iptables()
    except Exception as exc:
        logging.error(f"Ошибка получения правил firewall: {exc}")
        raise HTTPException(status_code=500, detail="Error loading firewall rules")
    return rules

@router.post("/rules")
async def add_firewall_rule(rule: Dict, current_user: Dict = Depends(get_current_active_user)):
    """Append a new rule to its chain.

    Returns ``{"status": "exists", "id": ...}`` when an identical rule is
    already listed, otherwise ``{"status": "ok", "id": ...}``.

    The returned id comes from the rule as ``get_firewall_rules_from_iptables``
    lists it after the append. It is therefore the id that GET lists and
    DELETE/PUT accept. iptables may normalize a rule when listing it (e.g.
    ``1.2.3.4`` -> ``1.2.3.4/32``, an implicit ``-m tcp``), so a hash of the
    submitted dict generally does not match.

    Raises:
        HTTPException: 400 for an invalid rule or one iptables rejects;
            500 for any other failure.
    """
    try:
        validate_rule(rule)
        current_rules = get_firewall_rules_from_iptables()
        rule_without_id = {k: v for k, v in rule.items() if k != "id"}
        for existing_rule in current_rules:
            existing_without_id = {k: v for k, v in existing_rule.items() if k != "id"}
            if existing_without_id == rule_without_id:
                return {"status": "exists", "id": existing_rule["id"]}
        cleanup_duplicate_rules()
        ensure_required_rules()
        # Snapshot the listed ids right before the append so the new rule
        # can be identified afterwards.
        known_ids = {r["id"] for r in get_firewall_rules_from_iptables()}
        result = subprocess.run(rule_to_cmd(rule), capture_output=True, text=True)
        if result.returncode != 0:
            # Do not report success when iptables refused the rule.
            raise HTTPException(status_code=400, detail=f"iptables rejected the rule: {result.stderr.strip()}")
        new_ids = [r["id"] for r in get_firewall_rules_from_iptables() if r["id"] not in known_ids]
        # Fall back to hashing the submitted rule when the new entry cannot be
        # told apart (e.g. it duplicates a rule that was already listed).
        rule_id = new_ids[-1] if new_ids else hashlib.sha256(str(rule_without_id).encode()).hexdigest()[:12]
        rule["id"] = rule_id
        return {"status": "ok", "id": rule_id}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Ошибка добавления правила firewall: {e}")
        raise HTTPException(status_code=500, detail=f"Error adding firewall rule: {str(e)}")

@router.delete("/rules/{rule_id}")
async def delete_firewall_rule(rule_id: str, current_user: Dict = Depends(get_current_active_user)):
    """Delete the listed rule whose id is *rule_id*.

    Errors raised while deleting from iptables are logged but do not fail
    the request. The service rules are ensured afterwards.

    Raises:
        HTTPException: 404 if no listed rule has this id; 500 on other
            failures.
    """
    try:
        target_rule = next(
            (listed for listed in get_firewall_rules_from_iptables() if listed.get("id") == rule_id),
            None,
        )
        if not target_rule:
            raise HTTPException(status_code=404, detail="Rule not found")
        try:
            delete_specific_rule({key: val for key, val in target_rule.items() if key != "id"})
        except Exception as exc:
            logging.error(f"Ошибка удаления правила из iptables: {exc}")
        ensure_required_rules()
        return {"status": "ok"}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Ошибка удаления правила firewall: {e}")
        raise HTTPException(status_code=500, detail=f"Error deleting firewall rule: {str(e)}")

@router.put("/rules/{rule_id}")
async def update_firewall_rule(rule_id: str, rule: Dict, current_user: Dict = Depends(get_current_active_user)):
    """Replace the listed rule *rule_id* with *rule*.

    The old rule is deleted and the new one is appended to its chain.
    Because ids are content hashes, a changed rule is listed under a new id
    afterwards. An unchanged rule is left as it is; re-appending it would
    only create a duplicate. If iptables rejects the new rule, the old one
    is re-added (best effort) and the request fails.

    Raises:
        HTTPException: 404 if *rule_id* is not listed; 400 for an invalid or
            duplicate rule or one iptables rejects; 500 on other failures.
    """
    try:
        validate_rule(rule)
        current_rules = get_firewall_rules_from_iptables()
        old_rule = next((r.copy() for r in current_rules if r.get("id") == rule_id), None)
        if old_rule is None:
            raise HTTPException(status_code=404, detail="Rule not found")
        rule["id"] = rule_id
        rule_without_id = {k: v for k, v in rule.items() if k != "id"}
        for existing_rule in current_rules:
            if existing_rule["id"] == rule_id:
                continue
            existing_without_id = {k: v for k, v in existing_rule.items() if k != "id"}
            if existing_without_id == rule_without_id:
                raise HTTPException(status_code=400, detail="Duplicate rule would be created")
        cleanup_duplicate_rules()
        ensure_required_rules()
        if old_rule == rule:
            return {"status": "ok"}
        old_rule_without_id = {k: v for k, v in old_rule.items() if k != "id"}
        old_removed = False
        try:
            delete_specific_rule(old_rule_without_id)
            old_removed = True
        except Exception as exc:
            logging.error(f"Ошибка удаления правила из iptables: {exc}")
        result = subprocess.run(rule_to_cmd(rule), capture_output=True, text=True)
        if result.returncode != 0:
            # Put the old rule back so a rejected update does not silently drop it.
            if old_removed:
                subprocess.run(rule_to_cmd(old_rule_without_id))
            raise HTTPException(status_code=400, detail=f"iptables rejected the rule: {result.stderr.strip()}")
        return {"status": "ok"}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Ошибка обновления правила firewall: {e}")
        raise HTTPException(status_code=500, detail=f"Error updating firewall rule: {str(e)}")

@router.get("/default_policy")
async def get_default_policy(current_user: Dict = Depends(get_current_active_user)):
    """Return the policy of INPUT, OUTPUT and FORWARD as ``{chain: policy}``.

    The policy is read from the ``Chain X (policy Y)`` header of
    ``iptables -nL X``. It defaults to "ACCEPT" when the header is not found.

    Raises:
        HTTPException: 500 on any failure.
    """
    def _policy_of(chain_name):
        listing = subprocess.run(["sudo", "iptables", "-nL", chain_name], capture_output=True, text=True)
        found = re.search(rf"Chain {chain_name} \(policy (\w+)\)", listing.stdout)
        return found.group(1) if found else "ACCEPT"

    try:
        return {name: _policy_of(name) for name in ("INPUT", "OUTPUT", "FORWARD")}
    except Exception as e:
        logging.error(f"Ошибка получения политики по умолчанию: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting default policy: {str(e)}")

@router.post("/default_policy")
async def set_default_policy(data: Dict, current_user: Dict = Depends(get_current_active_user)):
    """Set the policy of a built-in chain from ``{"chain": ..., "policy": ...}``.

    iptables only accepts ACCEPT or DROP as a chain policy. REJECT is a
    target, and ``iptables -P`` refuses it, so it is rejected up front
    instead of reporting success.

    Raises:
        HTTPException: 400 for an invalid chain or policy; 500 if iptables
            fails or on any other error.
    """
    try:
        chain = data.get("chain")
        policy = data.get("policy")
        if chain not in ("INPUT", "OUTPUT", "FORWARD"):
            raise HTTPException(status_code=400, detail="Invalid chain")
        if policy not in ("ACCEPT", "DROP"):
            raise HTTPException(status_code=400, detail="Invalid policy")
        result = subprocess.run(["sudo", "iptables", "-P", chain, policy], capture_output=True, text=True)
        if result.returncode != 0:
            # Surface the failure instead of returning "ok" for an unchanged policy.
            raise Exception(result.stderr)
        return {"status": "ok"}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Ошибка установки политики по умолчанию: {e}")
        raise HTTPException(status_code=500, detail=f"Error setting default policy: {str(e)}")

@router.post("/save")
async def save_firewall_rules(current_user: Dict = Depends(get_current_active_user)):
    """Write the output of ``iptables-save`` to ``BASE_DIR/iptables.conf``.

    Only users with the "admin" role or "edit" access to the "firewall"
    component may save.

    The dump is first written to ``iptables.conf.tmp`` and then moved over
    the target. An interrupted write therefore cannot leave a truncated
    ``iptables.conf``.

    Raises:
        HTTPException: 403 without permission; 500 if ``iptables-save`` or
            the write fails.
    """
    try:
        # Permission check: admin role or firewall:edit access only.
        if not (current_user.get("roles") and "admin" in current_user["roles"]) and not any(
            a.get("component") == "firewall" and a.get("access") == "edit"
            for a in current_user.get("access", [])
        ):
            raise HTTPException(status_code=403, detail="Not enough permissions")
        # Dump the current rule set.
        result = subprocess.run(
            ["sudo", "iptables-save"],
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            raise Exception(result.stderr)
        # Store it as BASE_DIR/iptables.conf, replacing the old file in one step.
        from app import BASE_DIR
        save_path = BASE_DIR / "iptables.conf"
        tmp_path = save_path.with_name(save_path.name + ".tmp")
        try:
            with open(tmp_path, "w") as f:
                f.write(result.stdout)
            tmp_path.replace(save_path)
        except Exception:
            # Do not leave a partial temp file behind; the error is re-raised.
            tmp_path.unlink(missing_ok=True)
            raise
        return {"status": "ok"}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Ошибка сохранения состояния firewall: {e}")
        raise HTTPException(status_code=500, detail=f"Error saving firewall rules: {str(e)}")