use wasm_bindgen::prelude::*;
use num_bigint::{BigInt, Sign};
use num_traits::{Zero, One, Signed, ToPrimitive};
use std::str::FromStr;

/// Number of fractional decimal digits used when `evaluate_expression` is called with `precision == 0`.
const DEFAULT_PRECISION: usize = 256;
/// Largest `precision` accepted by `evaluate_expression`; larger requests are rejected.
const MAX_PRECISION: usize = 10_000;
/// Nesting budget for `(...)` groups and `sqrt(...)` calls; each nested group consumes one level.
const MAX_RECURSION_DEPTH: usize = 100;
/// Largest integer exponent accepted by the `^` operator.
const MAX_EXPONENT: u32 = 10_000;

#[wasm_bindgen]
/// Cheap pre-parse check of an expression string.
///
/// Accepts only ASCII digits, the operators `+ - * / ^`, parentheses, `.`, the space
/// character and the letters needed to spell `sqrt` in any case. Parentheses must be
/// balanced and a `)` may never appear before its matching `(`. This does not check
/// grammar; the parser reports structural errors.
pub fn validate_expression(expr: &str) -> bool {
    // Operators, punctuation and the single permitted whitespace character.
    const SYMBOLS: &str = "+-*/().^ ";
    // Letters are allowed only so that `sqrt` (in any case) can be written.
    const SQRT_LETTERS: &str = "sqrtSQRT";

    let mut open_parens: usize = 0;
    let every_char_ok = expr.chars().all(|ch| {
        let permitted =
            ch.is_ascii_digit() || SYMBOLS.contains(ch) || SQRT_LETTERS.contains(ch);
        if !permitted {
            return false;
        }
        match ch {
            '(' => {
                open_parens += 1;
                true
            }
            // A ')' with no open '(' to close fails immediately.
            ')' => match open_parens.checked_sub(1) {
                Some(remaining) => {
                    open_parens = remaining;
                    true
                }
                None => false,
            },
            _ => true,
        }
    });

    // Every opened parenthesis must also have been closed.
    every_char_ok && open_parens == 0
}

/// Evaluates an arithmetic expression using arbitrary-precision fixed-point arithmetic.
///
/// Supported syntax:
/// - decimal numbers, optionally signed;
/// - `+ - * /` and `^` exponentiation;
/// - parentheses;
/// - `sqrt(...)` in any letter case.
///
/// Spaces are ignored.
///
/// `precision` is the number of fractional decimal digits carried by every intermediate
/// value; `0` selects `DEFAULT_PRECISION`. Trailing zeros are stripped from the output.
///
/// # Errors
/// Returns an error message in these cases:
/// - "Invalid expression" when `validate_expression` rejects the input.
/// - "Precision too high (max N)" when `precision` exceeds `MAX_PRECISION`.
/// - Any parse or evaluation error.
/// - "Unexpected character" when input is left over after a complete expression was parsed.
#[wasm_bindgen]
pub fn evaluate_expression(expr: &str, precision: usize) -> Result<String, String> {
    if !validate_expression(expr) {
        return Err("Invalid expression".to_string());
    }

    let final_precision = match precision {
        0 => DEFAULT_PRECISION,
        p if p > MAX_PRECISION => {
            return Err(format!("Precision too high (max {})", MAX_PRECISION));
        }
        p => p,
    };

    // Validation admits only ASCII, and ' ' is its only whitespace character, so after this
    // filter character offsets and byte offsets into `clean_expr` coincide.
    let clean_expr: String = expr.chars().filter(|c| !c.is_whitespace()).collect();
    // final_precision <= MAX_PRECISION, so the cast to u32 cannot truncate.
    let scale = BigInt::from(10).pow(final_precision as u32);

    let (result_value, consumed) =
        parse_expression_with_depth(&clean_expr, 0, MAX_RECURSION_DEPTH, &scale)?;

    // The parser stops at the first token it cannot use. Leftover input (e.g. "(1)(2)",
    // "1.2.3", "5sqrt(4)") must be rejected instead of being silently dropped.
    if consumed != clean_expr.len() {
        return Err("Unexpected character".to_string());
    }

    format_decimal(&result_value, final_precision, &scale)
}

/// Fixed-point decimal: the real value multiplied by `scale` (10^precision in
/// `evaluate_expression`) and stored as an integer.
type ScaledInt = BigInt;
/// Parser result: the parsed value and the index just past the consumed input.
type ParseResult = Result<(ScaledInt, usize), String>;

/// Renders a fixed-point value as a plain decimal string.
///
/// The fractional part is shown with at most `precision` digits, and trailing zeros are
/// removed. When nothing remains after the decimal point, the point is omitted too.
/// A minus sign is printed only when the rendered text is non-zero, so "-0" never appears.
fn format_decimal(value: &ScaledInt, precision: usize, scale: &BigInt) -> Result<String, String> {
    let magnitude = value.abs();
    let whole = &magnitude / scale;
    let leftover = &magnitude % scale;

    let whole_digits = if whole.is_zero() {
        String::from("0")
    } else {
        whole.to_str_radix(10)
    };

    // Fraction digits: left-pad to `precision`, cap at `precision`, then drop trailing zeros.
    let frac_digits: String = if precision == 0 {
        String::new()
    } else {
        let mut padded = format!("{:0>width$}", leftover.to_str_radix(10), width = precision);
        padded.truncate(precision);
        padded.trim_end_matches('0').to_owned()
    };

    let show_minus = value.is_negative() && !(whole.is_zero() && frac_digits.is_empty());
    let sign = if show_minus { "-" } else { "" };

    Ok(if frac_digits.is_empty() {
        format!("{}{}", sign, whole_digits)
    } else {
        format!("{}{}.{}", sign, whole_digits, frac_digits)
    })
}

// === Pure fixed-point arithmetic helpers ===

/// Adds two fixed-point values that share the same scale; no rescaling is needed.
fn add_scaled(a: &ScaledInt, b: &ScaledInt) -> ScaledInt {
    let mut total = a.clone();
    total += b;
    total
}

/// Subtracts `b` from `a`; both share the same scale, so no rescaling is needed.
fn sub_scaled(a: &ScaledInt, b: &ScaledInt) -> ScaledInt {
    let mut difference = a.clone();
    difference -= b;
    difference
}

/// Multiplies two fixed-point values.
///
/// The raw integer product carries the scale twice, so it is divided by `scale` once
/// (integer division, discarding the excess digits).
fn mul_scaled(a: &ScaledInt, b: &ScaledInt, scale: &BigInt) -> ScaledInt {
    let raw_product = a * b;
    raw_product / scale
}

/// Divides fixed-point `a` by fixed-point `b`.
///
/// The dividend is multiplied by `scale` first, so the integer quotient keeps one factor
/// of `scale`.
///
/// # Errors
/// "Division by zero" when `b` is zero.
fn div_scaled(a: &ScaledInt, b: &ScaledInt, scale: &BigInt) -> Result<ScaledInt, String> {
    if b.is_zero() {
        Err("Division by zero".to_string())
    } else {
        let widened = a * scale;
        Ok(widened / b)
    }
}

fn pow_scaled(base: &ScaledInt, exp: u32, scale: &BigInt) -> ScaledInt {
    if exp == 0 {
        return scale.clone();
    }
    let mut res = base.clone();
    for _ in 1..exp {
        res = mul_scaled(&res, base, scale);
    }
    res
}

/// Square root of a fixed-point value, truncated to the working precision.
///
/// With `x = v * scale`, the result is `floor(sqrt(v) * scale)`, which is the integer
/// square root of `x * scale`. It is computed by integer Newton iteration started from an
/// over-estimate. From above, the iterates decrease strictly until they reach the floor
/// root, so the loop stops as soon as they no longer decrease.
///
/// Waiting for two equal iterates instead can cycle forever between `r` and `r + 1` when
/// `x * scale + 1` is a perfect square, e.g. `sqrt(0.8)` at precision 1.
///
/// # Errors
/// "Square root of negative number" for negative `x`.
fn sqrt_scaled(x: &ScaledInt, scale: &BigInt) -> Result<ScaledInt, String> {
    if x.is_negative() {
        return Err("Square root of negative number".to_string());
    }
    if x.is_zero() {
        return Ok(BigInt::zero());
    }

    // max(x, scale) >= sqrt(x * scale): if x >= scale then x * scale <= x^2, otherwise
    // x * scale < scale^2. Starting at or above the root guarantees monotone descent, and
    // the start is never zero, so the smallest input (x = 1) cannot divide by zero.
    let mut guess = std::cmp::max(x, scale).clone();
    loop {
        // floor(x * scale / guess)
        let quotient = div_scaled(x, &guess, scale)?;
        let next = (&guess + &quotient) / BigInt::from(2);
        // Once the sequence stops strictly decreasing, `guess` is the floor root.
        if &guess - &next < BigInt::one() {
            return Ok(guess);
        }
        guess = next;
    }
}

// === Recursive-descent parsers ===

/// Parses a complete expression beginning at index `start` of `s`.
///
/// The grammar's top-level rule is addition/subtraction, so this delegates to
/// `parse_add_sub_with_depth`. It returns the value and the index just past the last
/// consumed character, and `max_depth` is forwarded unchanged.
fn parse_expression_with_depth(s: &str, start: usize, max_depth: usize, scale: &BigInt) -> ParseResult {
    parse_add_sub_with_depth(s, start, max_depth, scale)
}

/// Parses `term (('+' | '-') term)*` starting at `start`, evaluated left to right.
///
/// Returns the value and the index of the first character that is not part of the sum.
fn parse_add_sub_with_depth(s: &str, start: usize, max_depth: usize, scale: &BigInt) -> ParseResult {
    let (mut res, mut pos) = parse_mul_div_with_depth(s, start, max_depth, scale)?;
    // Byte lookup is O(1) and matches the byte-based `s.len()` bounds used throughout. The
    // previous `chars().nth(pos)` rescanned from the start on every operator, making long
    // sums quadratic. evaluate_expression only passes validated ASCII.
    loop {
        match s.as_bytes().get(pos) {
            Some(b'+') => {
                let (rhs, new_pos) = parse_mul_div_with_depth(s, pos + 1, max_depth, scale)?;
                res = add_scaled(&res, &rhs);
                pos = new_pos;
            }
            Some(b'-') => {
                let (rhs, new_pos) = parse_mul_div_with_depth(s, pos + 1, max_depth, scale)?;
                res = sub_scaled(&res, &rhs);
                pos = new_pos;
            }
            _ => break,
        }
    }
    Ok((res, pos))
}

/// Parses `power (('*' | '/') power)*` starting at `start`, evaluated left to right.
///
/// # Errors
/// Propagates parse errors and "Division by zero" from `div_scaled`.
fn parse_mul_div_with_depth(s: &str, start: usize, max_depth: usize, scale: &BigInt) -> ParseResult {
    let (mut res, mut pos) = parse_power_with_depth(s, start, max_depth, scale)?;
    // O(1) byte lookup instead of `chars().nth(pos)`, which rescanned the string on every
    // operator and made long products quadratic (the input is validated ASCII).
    loop {
        match s.as_bytes().get(pos) {
            Some(b'*') => {
                let (rhs, new_pos) = parse_power_with_depth(s, pos + 1, max_depth, scale)?;
                res = mul_scaled(&res, &rhs, scale);
                pos = new_pos;
            }
            Some(b'/') => {
                let (rhs, new_pos) = parse_power_with_depth(s, pos + 1, max_depth, scale)?;
                res = div_scaled(&res, &rhs, scale)?;
                pos = new_pos;
            }
            _ => break,
        }
    }
    Ok((res, pos))
}

fn parse_power_with_depth(s: &str, start: usize, max_depth: usize, scale: &BigInt) -> ParseResult {
    let (base, pos) = parse_primary_with_depth(s, start, max_depth, scale)?;
    if pos < s.len() && s.chars().nth(pos) == Some('^') {
        let (exp_val, new_pos) = parse_power_with_depth(s, pos + 1, max_depth, scale)?;
        let exp_big = &exp_val / scale;
        if exp_big.sign() != Sign::Plus {
            return Err("Negative exponents not supported".to_string());
        }
        let exp_u32 = match exp_big.to_u32() {
            Some(e) if e <= MAX_EXPONENT => e,
            _ => return Err("Exponent too large or not integer".to_string()),
        };
        let result = pow_scaled(&base, exp_u32, scale);
        return Ok((result, new_pos));
    }
    Ok((base, pos))
}

/// Parses a primary at `start`: `sqrt(<expr>)` in any letter case, a parenthesized
/// `(<expr>)`, or a possibly signed number literal.
///
/// Each nested `(` or `sqrt(` consumes one level of `max_depth`. "Recursion depth exceeded"
/// is returned when a group is opened with no levels left.
///
/// # Errors
/// Returns one of these messages:
/// - "Unexpected end" when `start` is past the input.
/// - "Expected '(' after sqrt".
/// - "Unmatched parenthesis".
/// - "Unexpected character".
///
/// Errors from the nested expression, `sqrt_scaled` or `parse_number_scaled` are
/// propagated unchanged.
fn parse_primary_with_depth(s: &str, start: usize, max_depth: usize, scale: &BigInt) -> ParseResult {
    // Byte offsets match the `s.len()` bounds and give O(1) access. The previous
    // `chars().nth()` and the per-call `String` allocation scanned from the beginning.
    let bytes = s.as_bytes();
    let c = match bytes.get(start) {
        Some(&c) => c,
        None => return Err("Unexpected end".to_string()),
    };

    // sqrt(...) in any letter case.
    let is_sqrt = bytes
        .get(start..start + 4)
        .map_or(false, |word| word.eq_ignore_ascii_case(b"sqrt"));
    if is_sqrt {
        if max_depth == 0 {
            return Err("Recursion depth exceeded".to_string());
        }
        if start + 5 >= bytes.len() || bytes[start + 4] != b'(' {
            return Err("Expected '(' after sqrt".to_string());
        }
        let (inner, pos) = parse_add_sub_with_depth(s, start + 5, max_depth - 1, scale)?;
        if bytes.get(pos) != Some(&b')') {
            return Err("Unmatched parenthesis".to_string());
        }
        let result = sqrt_scaled(&inner, scale)?;
        return Ok((result, pos + 1));
    }

    if c == b'(' {
        if max_depth == 0 {
            return Err("Recursion depth exceeded".to_string());
        }
        let (res, pos) = parse_add_sub_with_depth(s, start + 1, max_depth - 1, scale)?;
        if bytes.get(pos) != Some(&b')') {
            return Err("Unmatched parenthesis".to_string());
        }
        return Ok((res, pos + 1));
    }

    if c.is_ascii_digit() || matches!(c, b'+' | b'-' | b'.') {
        return parse_number_scaled(s, start, scale);
    }

    Err("Unexpected character".to_string())
}

/// Parses a decimal literal at `start`, such as `12`, `-3.5`, `+.25` or `7.`, into a
/// fixed-point value.
///
/// Fractional digits beyond the precision implied by `scale` are truncated. Returns the
/// value and the index just past the literal.
///
/// # Errors
/// Returns one of these messages:
/// - "Incomplete number" for a sign at the end of the input.
/// - "Invalid number format" for a sign not followed by a digit or '.'.
/// - "Empty number".
/// - "Invalid number: only dot".
/// - "Invalid integer part" / "Invalid fractional part" when `BigInt` parsing fails.
fn parse_number_scaled(s: &str, start: usize, scale: &BigInt) -> ParseResult {
    // Byte offsets, consistent with `s.len()`; literal syntax is pure ASCII. The previous
    // `chars().nth(i)` for every character made long literals quadratic.
    let bytes = s.as_bytes();
    let mut i = start;
    let mut negative = false;

    if let Some(&ch) = bytes.get(i) {
        if ch == b'+' || ch == b'-' {
            negative = ch == b'-';
            i += 1;
            match bytes.get(i) {
                None => return Err("Incomplete number".to_string()),
                Some(&next) if !next.is_ascii_digit() && next != b'.' => {
                    return Err("Invalid number format".to_string());
                }
                Some(_) => {}
            }
        }
    }

    // Digits with at most one '.'; a second '.' ends the literal.
    let start_num = i;
    let mut dot_found = false;
    while let Some(&ch) = bytes.get(i) {
        if ch.is_ascii_digit() {
            i += 1;
        } else if ch == b'.' && !dot_found {
            dot_found = true;
            i += 1;
        } else {
            break;
        }
    }

    if i == start_num {
        return Err("Empty number".to_string());
    }

    // Every byte in start_num..i is an ASCII digit or '.', so both ends are char boundaries.
    let num_str = &s[start_num..i];
    if num_str == "." {
        return Err("Invalid number: only dot".to_string());
    }

    let (int_part, frac_part) = match num_str.find('.') {
        Some(dot) => (&num_str[..dot], &num_str[dot + 1..]),
        None => (num_str, ""),
    };
    // A literal such as ".5" has an empty integer part.
    let int_part = if int_part.is_empty() { "0" } else { int_part };

    let integer_big = BigInt::from_str(int_part).map_err(|_| "Invalid integer part".to_string())?;
    let mut fractional_big = BigInt::zero();

    if !frac_part.is_empty() {
        // Fractional digits carried by `scale` (10^p has p + 1 decimal digits).
        let precision = scale.to_str_radix(10).len() - 1;
        // Digits beyond the working precision are truncated.
        let frac_use = &frac_part[..frac_part.len().min(precision)];
        if !frac_use.is_empty() {
            let digits =
                BigInt::from_str(frac_use).map_err(|_| "Invalid fractional part".to_string())?;
            // Left-align the digits: "5" at precision 3 means 500. `missing` is bounded by the
            // scale's digit count (at most MAX_PRECISION when called via evaluate_expression),
            // so the cast to u32 does not truncate in practice.
            let missing = (precision - frac_use.len()) as u32;
            fractional_big = digits * BigInt::from(10).pow(missing);
        }
    }

    let mut result = integer_big * scale + fractional_big;
    if negative {
        result = -result;
    }

    Ok((result, i))
}

// === Tests ===

#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `expr` and returns the formatted result, panicking if evaluation fails.
    fn eval_ok(expr: &str, precision: usize) -> String {
        evaluate_expression(expr, precision).unwrap()
    }

    /// Evaluates `expr` and returns the error message, panicking if evaluation succeeds.
    fn eval_err(expr: &str, precision: usize) -> String {
        evaluate_expression(expr, precision).unwrap_err()
    }

    #[test]
    fn test_validate_expression() {
        let accepted = ["1+2", "sqrt(16)", "SQRT(4)", "123.456*(789+0.123)", "2^10"];
        for expr in accepted.iter() {
            assert!(validate_expression(expr), "should be accepted: {}", expr);
        }
        // Letters outside "sqrt" are rejected by validation.
        assert!(!validate_expression("1+2*abc"));
        // Unbalanced parentheses in either direction.
        assert!(!validate_expression("((1+2)"));
        assert!(!validate_expression("1+2))"));
    }

    #[test]
    fn test_simple_arithmetic() {
        let cases = [("1+2", "3"), ("10-4", "6"), ("3*4", "12"), ("15/3", "5")];
        for (expr, expected) in cases.iter() {
            assert_eq!(eval_ok(expr, 0), *expected);
        }
    }

    #[test]
    fn test_decimal_precision() {
        assert_eq!(eval_ok("1/3", 10), "0.3333333333");
        assert_eq!(eval_ok("1/3", 5), "0.33333");
    }

    #[test]
    fn test_sqrt() {
        assert_eq!(eval_ok("sqrt(4)", 0), "2");
        assert_eq!(eval_ok("sqrt(100)", 0), "10");
        assert!(eval_ok("sqrt(2)", 5).starts_with("1.4142"));
    }

    #[test]
    fn test_power() {
        assert_eq!(eval_ok("2^10", 0), "1024");
        assert_eq!(eval_ok("5^3", 0), "125");
    }

    #[test]
    fn test_large_numbers() {
        let expr = "123456789012345678901234567890 + 987654321098765432109876543210";
        assert_eq!(eval_ok(expr, 0), "1111111110111111111011111111100");
    }

    #[test]
    fn test_nested_parentheses() {
        assert_eq!(eval_ok("((2+3)*4)", 0), "20");
    }

    #[test]
    fn test_division_by_zero() {
        assert!(eval_err("1/0", 0).contains("Division by zero"));
    }

    #[test]
    fn test_negative_sqrt() {
        assert!(eval_err("sqrt(-1)", 0).contains("Square root of negative number"));
    }

    #[test]
    fn test_unmatched_parentheses() {
        // These are rejected during validation, before parsing starts.
        assert!(!validate_expression("(1+2"));
        assert!(!validate_expression("1+2)"));
    }

    #[test]
    fn test_exponent_too_large() {
        let expr = format!("2^{}", MAX_EXPONENT + 1);
        assert!(eval_err(&expr, 0).contains("Exponent too large"));
    }

    #[test]
    fn test_precision_limit() {
        assert!(eval_err("1/3", MAX_PRECISION + 1).contains("Precision too high"));
    }

    #[test]
    fn test_fractional_input() {
        assert_eq!(eval_ok("0.1 + 0.2", 1), "0.3");
    }

    #[test]
    fn test_unary_minus() {
        assert_eq!(eval_ok("-5 + 10", 0), "5");
    }

    #[test]
    fn test_invalid_function() {
        // "abc" is rejected by validation, so evaluate_expression reports
        // "Invalid expression" rather than a parser-level "Unexpected character".
        assert_eq!(eval_err("1+2*abc", 0), "Invalid expression");
    }
}