"""
WFIS Market Analyzer - Core Analytics Engine
Quantitative Liquidity Intelligence
"""

import numpy as np
import pandas as pd
from scipy import stats
from scipy.signal import find_peaks
from scipy.stats import gaussian_kde
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
from arch import arch_model
import warnings
# NOTE(review): this installs a global "ignore" filter for every warning
# category as soon as the module is imported, so warnings raised by the
# numerical routines used below are hidden as well.
warnings.filterwarnings('ignore')

class MarketAnalyzer:
    """
    Core quantitative analysis engine for WFIS platform.
    Implements:
    - EGARCH volatility modeling
    - Monte Carlo simulation with t-Student distribution
    - Liquidity zone detection via KDE
    - Anomaly detection with Isolation Forest
    - VaR/CVaR calculation
    """
    
    def __init__(self, price_data, volume_data=None):
        self.price_data = price_data
        self.volume_data = volume_data
        self.returns = np.log(price_data / price_data.shift(1)).dropna()
        self._cache = {}
        
    def detect_liquidity_zones(self, bandwidth=0.02, n_points=500):
        """
        Detect liquidity zones using KDE with volume weighting.
        
        Returns:
            List of zones with low/high boundaries and strength scores
        """
        cache_key = f"liquidity_zones_{bandwidth}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        
        prices = self.price_data.values
        volumes = self.volume_data.values if self.volume_data is not None else np.ones(len(prices))
        
        # Volume-weighted KDE
        weights = volumes / volumes.sum()
        kde = gaussian_kde(prices, bw_method=bandwidth, weights=weights)
        price_grid = np.linspace(prices.min(), prices.max(), n_points)
        density = kde(price_grid)
        
        # Find peaks
        peaks, properties = find_peaks(density, height=np.percentile(density, 70))
        
        liquidity_zones = []
        for peak in peaks[:5]:
            peak_price = price_grid[peak]
            peak_height = density[peak]
            
            # Find width at half height
            half_height = peak_height / 2
            left = peak
            right = peak
            while left > 0 and density[left] > half_height:
                left -= 1
            while right < len(density) - 1 and density[right] > half_height:
                right += 1
            
            # Calculate zone importance score
            zone_volume = volumes[(prices >= price_grid[left]) & (prices <= price_grid[right])].sum()
            total_volume = volumes.sum()
            volume_importance = zone_volume / total_volume if total_volume > 0 else 0
            
            liquidity_zones.append({
                'level': float(peak_price),
                'low': float(price_grid[left]),
                'high': float(price_grid[right]),
                'center': float(peak_price),
                'strength': float(peak_height / density.max()),
                'volume_accumulated': float(zone_volume),
                'volume_importance': float(volume_importance)
            })
        
        liquidity_zones.sort(key=lambda x: x['strength'], reverse=True)
        self._cache[cache_key] = liquidity_zones
        return liquidity_zones
    
    def detect_anomalies(self, contamination=0.05):
        """
        Detect order flow anomalies using Isolation Forest.
        
        Returns:
            str: 'HIGH', 'MODERATE', or 'LOW'
        """
        if self.volume_data is None:
            return "LOW"
        
        df = pd.DataFrame({
            'price': self.price_data,
            'volume': self.volume_data
        })
        
        # Feature engineering
        df['volume_ma'] = df['volume'].rolling(20).mean()
        df['volume_ratio'] = df['volume'] / df['volume_ma']
        df['price_range'] = (df['price'] - df['price'].shift(1)).abs() / df['price']
        df['returns'] = df['price'].pct_change()
        
        features = df[['volume_ratio', 'price_range', 'returns']].fillna(0).values[-100:]
        
        if len(features) < 10:
            return "LOW"
        
        # Isolation Forest
        iso_forest = IsolationForest(contamination=contamination, random_state=42)
        predictions = iso_forest.fit_predict(features)
        
        anomaly_percentage = (predictions == -1).mean() * 100
        latest_volume_ratio = df['volume_ratio'].iloc[-1] if not pd.isna(df['volume_ratio'].iloc[-1]) else 1
        
        if anomaly_percentage > 15 or latest_volume_ratio > 2.5:
            return "HIGH"
        elif anomaly_percentage > 5 or latest_volume_ratio > 1.8:
            return "MODERATE"
        else:
            return "LOW"
    
    def analyze_volatility(self, model_type='EGARCH'):
        """
        Analyze volatility outlook using EGARCH model.
        
        Returns:
            dict: volatility metrics and outlook
        """
        # Fit EGARCH model
        try:
            model = arch_model(self.returns * 100, vol='EGARCH', p=1, q=1, o=1, dist='t')
            result = model.fit(disp='off')
            
            # Current volatility
            current_vol = self.returns.std() * np.sqrt(252) * 100
            
            # Historical percentile
            rolling_vol = self.returns.rolling(252).std().dropna() * np.sqrt(252) * 100
            if len(rolling_vol) > 0:
                percentile = (rolling_vol < current_vol).mean() * 100
            else:
                percentile = 50
            
            # Forecast
            forecast = result.forecast(horizon=48)
            forecast_vol = np.sqrt(forecast.variance.values[-1]) / 100
            
            if percentile > 70:
                outlook = "Elevated risk"
            elif percentile < 30:
                outlook = "Low risk"
            else:
                outlook = "Normal risk"
            
            return {
                'current_volatility': float(current_vol),
                'historical_percentile': float(percentile),
                'outlook': outlook,
                'forecast_volatility': forecast_vol.tolist() if isinstance(forecast_vol, np.ndarray) else float(forecast_vol),
                'model_fitted': True
            }
        except Exception as e:
            return {
                'outlook': "Unable to forecast",
                'model_fitted': False,
                'error': str(e)
            }
    
    def monte_carlo_simulation(self, n_paths=20000, n_steps=48, horizon_days=2):
        """
        Run Monte Carlo simulation with t-Student distribution.
        
        Returns:
            np.ndarray: Simulated price paths
        """
        params = self._estimate_distribution_params()
        S0 = self.price_data.iloc[-1]
        mu = params['mu']
        df = params['df']
        scale = params['scale']
        
        dt = horizon_days / n_steps
        
        # Generate t-Student shocks
        epsilon = np.random.standard_t(df, size=(n_paths, n_steps))
        
        # Apply skewness adjustment
        if params['skewness'] > 0.2:
            epsilon = epsilon + params['skewness'] * 0.3
        elif params['skewness'] < -0.2:
            epsilon = epsilon + params['skewness'] * 0.3
        
        # GBM with t-Student shocks
        log_returns = (mu - 0.5 * scale**2) * dt + scale * np.sqrt(dt) * epsilon
        
        # Simulate paths
        paths = np.zeros((n_paths, n_steps + 1))
        paths[:, 0] = S0
        
        for t in range(n_steps):
            paths[:, t+1] = paths[:, t] * np.exp(log_returns[:, t])
        
        return paths
    
    def _estimate_distribution_params(self):
        """Estimate parameters for t-Student distribution"""
        params = stats.t.fit(self.returns.dropna())
        df, loc, scale = params
        mu = self.returns.mean()
        skewness = self.returns.skew()
        
        return {
            'mu': mu,
            'df': df,
            'scale': scale,
            'skewness': skewness,
            'current_price': self.price_data.iloc[-1]
        }
    
    def extract_scenarios(self, paths, current_price):
        """
        Extract bullish/bearish/neutral scenarios from Monte Carlo paths.
        
        Returns:
            dict: Scenario probabilities and price levels
        """
        max_prices = np.max(paths, axis=1)
        min_prices = np.min(paths, axis=1)
        
        # Dynamic thresholds (4% for trigger, ~7.6% for target)
        bull_trigger = current_price * 1.04
        bull_target = current_price * 1.076
        bear_trigger = current_price * 0.986
        bear_target = current_price * 0.953
        
        bull_scenario = (max_prices > bull_trigger) & (max_prices >= bull_target)
        bear_scenario = (min_prices < bear_trigger) & (min_prices <= bear_target)
        neutral_scenario = (max_prices < bull_trigger) & (min_prices > bear_trigger)
        
        return {
            'probabilities': {
                'bullish': float(np.mean(bull_scenario)),
                'bearish': float(np.mean(bear_scenario)),
                'neutral': float(np.mean(neutral_scenario))
            },
            'levels': {
                'bull_trigger': float(bull_trigger),
                'bull_target': float(bull_target),
                'bear_trigger': float(bear_trigger),
                'bear_target': float(bear_target)
            }
        }
    
    def calculate_var(self, paths, confidence=0.95):
        """
        Calculate Value at Risk (VaR) and Conditional VaR.
        
        Returns:
            dict: VaR and CVaR metrics
        """
        current_price = self.price_data.iloc[-1]
        final_prices = paths[:, -1]
        
        # Calculate returns
        returns = (final_prices - current_price) / current_price
        losses = -returns
        
        var = np.percentile(losses, (1 - confidence) * 100)
        cvar = losses[losses >= var].mean() if np.any(losses >= var) else var
        
        return {
            'VaR_95': float(var) * 100,
            'CVaR_95': float(cvar) * 100,
            'max_loss': float(losses.max()) * 100,
            'max_gain': float(returns.max()) * 100,
            'expected_return': float(returns.mean()) * 100
        }
    
    def get_technical_indicators(self):
        """
        Calculate technical indicators for the asset.
        
        Returns:
            dict: Various technical indicators
        """
        prices = self.price_data
        
        # Moving averages
        ma_20 = prices.rolling(20).mean().iloc[-1] if len(prices) >= 20 else prices.iloc[-1]
        ma_50 = prices.rolling(50).mean().iloc[-1] if len(prices) >= 50 else prices.iloc[-1]
        ma_200 = prices.rolling(200).mean().iloc[-1] if len(prices) >= 200 else prices.iloc[-1]
        
        # RSI
        delta = prices.diff()
        gain = delta.where(delta > 0, 0).rolling(14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs)).iloc[-1] if loss.iloc[-1] != 0 else 50
        
        # MACD
        ema_12 = prices.ewm(span=12, adjust=False).mean()
        ema_26 = prices.ewm(span=26, adjust=False).mean()
        macd = ema_12 - ema_26
        macd_signal = macd.ewm(span=9, adjust=False).mean()
        macd_histogram = macd - macd_signal
        
        # Bollinger Bands
        bb_middle = prices.rolling(20).mean()
        bb_std = prices.rolling(20).std()
        bb_upper = bb_middle + (bb_std * 2)
        bb_lower = bb_middle - (bb_std * 2)
        
        # Volatility
        volatility = self.returns.std() * np.sqrt(252) * 100
        
        # Support/Resistance (recent)
        recent_high = prices.tail(20).max()
        recent_low = prices.tail(20).min()
        
        return {
            'current_price': float(prices.iloc[-1]),
            'ma_20': float(ma_20),
            'ma_50': float(ma_50),
            'ma_200': float(ma_200),
            'rsi': float(rsi),
            'macd': float(macd.iloc[-1]),
            'macd_signal': float(macd_signal.iloc[-1]),
            'macd_histogram': float(macd_histogram.iloc[-1]),
            'bb_upper': float(bb_upper.iloc[-1]) if not pd.isna(bb_upper.iloc[-1]) else prices.iloc[-1],
            'bb_middle': float(bb_middle.iloc[-1]) if not pd.isna(bb_middle.iloc[-1]) else prices.iloc[-1],
            'bb_lower': float(bb_lower.iloc[-1]) if not pd.isna(bb_lower.iloc[-1]) else prices.iloc[-1],
            'volatility_annual': float(volatility),
            'recent_high': float(recent_high),
            'recent_low': float(recent_low)
        }