"""
Wildfire Burn Severity & Vegetation Impact Analysis
Sierra Ridge Mutual Water Company — Service Area Brief
Uses Sentinel-2 L2A via Microsoft Planetary Computer STAC
"""

import json
import warnings
import numpy as np
import geopandas as gpd
import xarray as xr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.patches import Patch
import matplotlib.ticker as mticker
from shapely.geometry import shape, mapping
import pystac_client
import planetary_computer
import odc.stac

warnings.filterwarnings("ignore")  # NOTE: blanket filter; hides every warning, deprecations included

# ─── Config ──────────────────────────────────────────────────────────────
# I/O locations
GEOJSON_PATH = "service-area.geojson"  # service-area polygon(s)
OUTPUT_DIR = "."                       # destination for maps, histogram and brief

# Analysis grid and scene filtering
TARGET_CRS = "EPSG:32610"  # WGS 84 / UTM zone 10N, chosen for the Redding area
CLOUD_COVER_MAX = 15       # upper limit on scene-level eo:cloud_cover (percent)
RESOLUTION = 20            # output pixel size in metres (native B11/B12 grid)

# Burn-severity classes: label -> (low, high) dNBR bounds.
# Insertion order is the legend order and defines class codes 1..7 downstream.
# Thresholds: USGS dNBR scheme (Key & Benson 2006).
SEVERITY_CLASSES = dict(
    [
        ("Enhanced Regrowth (High)", (-0.500, -0.251)),
        ("Enhanced Regrowth (Low)", (-0.250, -0.101)),
        ("Unburned", (-0.100, 0.099)),
        ("Low Severity", (0.100, 0.269)),
        ("Moderate-Low Severity", (0.270, 0.439)),
        ("Moderate-High Severity", (0.440, 0.659)),
        ("High Severity", (0.660, 1.300)),
    ]
)

# Fill colour for each class above, in the same order
# (greens = regrowth, near-white = unburned, yellow -> red = rising severity).
SEVERITY_COLORS = [
    "#1a9641", "#a6d96a",             # enhanced regrowth: high, low
    "#f7f7f7",                        # unburned
    "#fee08b", "#fdae61", "#f46d43",  # low, moderate-low, moderate-high
    "#d73027",                        # high
]

NDVI_CMAP = "RdYlGn"  # matplotlib colormap name for the NDVI panel

# ─── 1. Load service area ────────────────────────────────────────────────
print("Loading service area...")
aoi = gpd.read_file(GEOJSON_PATH)
if aoi.empty:
    raise RuntimeError(f"Service area file {GEOJSON_PATH!r} contains no features.")
aoi_4326 = aoi.to_crs("EPSG:4326")  # lon/lat copy, used for the STAC bbox search
aoi_utm = aoi.to_crs(TARGET_CRS)    # metric copy, used for areas and clipping
bounds = aoi_4326.total_bounds  # [minx, miny, maxx, maxy]
# Plain Python floats: avoids NumPy 2's "np.float64(...)" repr in the log line
# below and keeps the bbox trivially JSON-serialisable for the STAC request.
bbox = [float(v) for v in bounds]

# Area of all features combined (UTM metres -> km2)
total_area_km2 = aoi_utm.geometry.area.sum() / 1e6
print(f"  Service area: {total_area_km2:.1f} km2")
print(f"  Bounding box: {bbox}")

# ─── 2. Query Sentinel-2 from Planetary Computer ────────────────────────
print("\nSearching Planetary Computer STAC catalog...")
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

# Post-fire image: most recent scene under the cloud-cover limit in the window
post_search = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=bbox,
    datetime="2025-06-01/2025-10-31",
    query={"eo:cloud_cover": {"lt": CLOUD_COVER_MAX}},
    sortby=[{"field": "datetime", "direction": "desc"}],
    max_items=10,
)
post_items = list(post_search.items())
print(f"  Post-fire candidates: {len(post_items)} scenes")

if not post_items:
    raise RuntimeError("No suitable post-fire Sentinel-2 scenes found.")

post_item = post_items[0]
post_date = post_item.datetime
print(f"  Selected post-fire scene: {post_item.id} ({post_date.strftime('%Y-%m-%d')})")


def _shift_month(year, month, delta):
    """Return ``(year, month)`` moved by *delta* months, rolling across year ends.

    >>> _shift_month(2024, 1, -1)
    (2023, 12)
    >>> _shift_month(2024, 12, 1)
    (2025, 1)
    """
    index = year * 12 + (month - 1) + delta
    return index // 12, index % 12 + 1


# Pre-fire image: same season one year earlier for dNBR. The window runs from
# the 1st of the month before to the 28th of the month after the post-fire
# month (day 28 exists in every month); month shifts wrap across years so
# January/December post-fire dates still yield a valid range.
pre_year = post_date.year - 1
start_year, start_month = _shift_month(pre_year, post_date.month, -1)
end_year, end_month = _shift_month(pre_year, post_date.month, +1)
pre_window = f"{start_year}-{start_month:02d}-01/{end_year}-{end_month:02d}-28"

pre_search = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=bbox,
    datetime=pre_window,
    query={"eo:cloud_cover": {"lt": CLOUD_COVER_MAX}},
    sortby=[{"field": "datetime", "direction": "desc"}],
    max_items=10,
)
pre_items = list(pre_search.items())
if not pre_items:
    raise RuntimeError(f"No suitable pre-fire Sentinel-2 scenes found in {pre_window}.")

pre_item = pre_items[0]
pre_date = pre_item.datetime
print(f"  Selected pre-fire scene:  {pre_item.id} ({pre_date.strftime('%Y-%m-%d')})")

# ─── 3. Load bands via odc-stac ─────────────────────────────────────────
print("\nLoading Sentinel-2 bands (B04, B08, B11, B12, SCL)...")
bands = ["B04", "B08", "B11", "B12", "SCL"]
reflectance_bands = [b for b in bands if b != "SCL"]  # SCL is a class map, never offset

# ESA processing baseline 04.00 (in force from 2022-01-25) adds +1000 to every
# L2A reflectance DN (BOA_ADD_OFFSET = -1000). Planetary Computer serves the
# DNs unmodified and odc.stac.load does not rescale them. Left in place, the
# offset shrinks every normalized difference toward zero:
# (a - b) / (a + b + 2000). That shrinkage biases NBR, dNBR and NDVI.
S2_BOA_OFFSET = 1000
S2_OFFSET_START = (2022, 1, 25)


def _boa_offset(item):
    """Return the DN offset (0 or 1000) to remove from *item*'s reflectance bands.

    Prefers the item's ``s2:processing_baseline`` property (>= 04.00 carries the
    offset). Falls back to the acquisition date when that is missing or unparseable.
    """
    baseline = item.properties.get("s2:processing_baseline")
    if baseline is not None:
        try:
            return S2_BOA_OFFSET if float(baseline) >= 4.0 else 0
        except (TypeError, ValueError):
            pass  # unparseable baseline string: use the date rule below
    acquired = item.datetime
    on_new_baseline = (acquired.year, acquired.month, acquired.day) >= S2_OFFSET_START
    return S2_BOA_OFFSET if on_new_baseline else 0


def _load_scene(item):
    """Load *item* on the shared analysis grid with the baseline offset removed."""
    ds = odc.stac.load(
        [item], bands=bands, resolution=RESOLUTION,
        crs=TARGET_CRS, bbox=bbox, chunks={}
    ).isel(time=0)
    offset = _boa_offset(item)
    if offset:
        for band in reflectance_bands:
            # Clip first so the unsigned subtraction cannot wrap below zero;
            # DN 0 (nodata) stays 0.
            ds[band] = ds[band].clip(min=offset) - offset
    return ds


ds_post = _load_scene(post_item)
ds_pre = _load_scene(pre_item)

print("  Loaded. Computing...")

# ─── 4. Cloud masking (SCL-based) ───────────────────────────────────────
def cloud_mask(ds):
    """Blank out every pixel of *ds* whose Scene Classification (SCL) code is unusable.

    Rejected SCL codes: 0 (no data), 1 (defective), 3 (cloud shadow),
    8 and 9 (cloud, medium / high probability) and 10 (cirrus). All other
    classes are kept. Blanking is delegated to ``ds.where``, applied to the
    whole dataset (SCL included).
    """
    rejected_codes = [0, 1, 3, 8, 9, 10]
    is_rejected = ds["SCL"].isin(rejected_codes)
    keep = ~is_rejected
    return ds.where(keep)

# Apply the SCL cloud/shadow mask to both dates (post-fire first, then pre-fire).
ds_post, ds_pre = cloud_mask(ds_post), cloud_mask(ds_pre)

# ─── 5. Compute indices ─────────────────────────────────────────────────
def _normalized_difference(ds, band_a, band_b):
    """Return ``(a - b) / (a + b)`` for two bands of *ds*, evaluated in float.

    A 1e-8 term in the denominator stops pixels where both bands are zero from
    dividing by zero; those pixels evaluate to 0.
    """
    a = ds[band_a].astype(float)
    b = ds[band_b].astype(float)
    return (a - b) / (a + b + 1e-8)


def calc_nbr(ds):
    """Normalized Burn Ratio: NIR (B08) against SWIR (B12)."""
    return _normalized_difference(ds, "B08", "B12")


def calc_ndvi(ds):
    """Normalized Difference Vegetation Index: NIR (B08) against red (B04)."""
    return _normalized_difference(ds, "B08", "B04")

# Materialise every index now (the bands were loaded lazily with chunks={}),
# so the numpy work below operates on in-memory arrays.
nbr_pre, nbr_post = [calc_nbr(scene).compute() for scene in (ds_pre, ds_post)]
dnbr = (nbr_pre - nbr_post).compute()  # positive dNBR = vegetation loss since pre-fire
ndvi_post = calc_ndvi(ds_post).compute()

print("  NBR pre/post, dNBR, and NDVI computed.")

# ─── 6. Clip to service area ────────────────────────────────────────────
print("\nClipping to service area polygon...")
from rasterio.features import geometry_mask
from shapely.ops import unary_union

# Dissolve every feature of the service area (not only the first one).
# The raster mask then covers the same footprint total_area_km2 was measured on.
aoi_utm_geom = unary_union([geom for geom in aoi_utm.geometry if geom is not None])
if aoi_utm_geom.is_empty:
    raise RuntimeError("Service area has no non-empty geometry to clip against.")

# Build a boolean mask (True = inside the AOI) on the dNBR pixel grid.
# geometry_mask returns True *outside* the shapes, hence the inversion.
transform = dnbr.odc.transform
out_shape = (dnbr.sizes["y"], dnbr.sizes["x"])
aoi_mask = ~geometry_mask(
    [mapping(aoi_utm_geom)], out_shape=out_shape,
    transform=transform, all_touched=True  # any pixel touching the boundary counts as inside
)
aoi_mask_da = xr.DataArray(aoi_mask, dims=["y", "x"], coords={"y": dnbr.y, "x": dnbr.x})

dnbr_clipped = dnbr.where(aoi_mask_da)
ndvi_clipped = ndvi_post.where(aoi_mask_da)

# ─── 7. Classify burn severity ──────────────────────────────────────────
print("Classifying burn severity...")
dnbr_vals = dnbr_clipped.values

# The USGS bounds in SEVERITY_CLASSES are quoted from a table of integer
# dNBR*1000 values (e.g. 0.099 followed by 0.100). Read literally on continuous
# float dNBR, they leave gaps: 0.0995 would fall into no class. Each class is
# therefore treated as the half-open interval [lo, next class's lo). The last
# class is closed at its own upper bound. Values outside the overall range
# (and masked NaNs) stay unclassified (NaN), as before.
thresholds = list(SEVERITY_CLASSES.values())
lower_bounds = [lo for lo, _ in thresholds]
upper_bounds = lower_bounds[1:] + [thresholds[-1][1]]
n_classes = len(thresholds)

severity = np.full(dnbr_clipped.shape, np.nan)
for code, (lo, hi) in enumerate(zip(lower_bounds, upper_bounds), start=1):
    below_hi = (dnbr_vals <= hi) if code == n_classes else (dnbr_vals < hi)
    severity[(dnbr_vals >= lo) & below_hi] = code  # codes 1..n follow SEVERITY_CLASSES order

severity_da = xr.DataArray(severity, dims=["y", "x"], coords={"y": dnbr.y, "x": dnbr.x})

# ─── 8. Compute area statistics ─────────────────────────────────────────
print("Computing area statistics...\n")
pixel_area_m2 = RESOLUTION * RESOLUTION
pixel_area_km2 = pixel_area_m2 / 1e6

class_names = list(SEVERITY_CLASSES.keys())
# Count each class exactly once; NaN (masked / unclassified) never equals a code.
class_counts = {
    name: int(np.count_nonzero(severity == code))
    for code, name in enumerate(class_names, start=1)
}
class_areas = {name: count * pixel_area_km2 for name, count in class_counts.items()}
total_valid = sum(class_counts.values())  # classified pixels: the % denominator

stats_lines = [
    f"{'Severity Class':<30s} {'Pixels':>8s} {'Area (km2)':>12s} {'% of Valid':>10s}",
    "-" * 64,
]
for name, count in class_counts.items():
    pct = (count / total_valid * 100) if total_valid > 0 else 0
    stats_lines.append(f"{name:<30s} {count:>8d} {class_areas[name]:>12.3f} {pct:>9.1f}%")

total_pct = "100.0%" if total_valid > 0 else "0.0%"
stats_lines.append("-" * 64)
stats_lines.append(f"{'TOTAL VALID PIXELS':<30s} {total_valid:>8d} {total_valid * pixel_area_km2:>12.3f} {total_pct:>10s}")

stats_text = "\n".join(stats_lines)
print(stats_text)

# NDVI summary
ndvi_all = ndvi_clipped.values
ndvi_vals = ndvi_all[~np.isnan(ndvi_all)]  # in-AOI, unmasked pixels only
if ndvi_vals.size == 0:
    # np.min/np.max raise on empty input and the percentages would divide by 0
    raise RuntimeError(
        "No valid NDVI pixels inside the service area (post-fire scene fully "
        "masked or not overlapping the AOI); cannot summarise vegetation health."
    )
n_ndvi = ndvi_vals.size
ndvi_stats = {
    "mean": float(np.mean(ndvi_vals)),
    "median": float(np.median(ndvi_vals)),
    "std": float(np.std(ndvi_vals)),
    "min": float(np.min(ndvi_vals)),
    "max": float(np.max(ndvi_vals)),
    "pct_stressed": float(np.count_nonzero(ndvi_vals < 0.3) / n_ndvi * 100),
    "pct_healthy": float(np.count_nonzero(ndvi_vals > 0.5) / n_ndvi * 100),
}
print("\nNDVI Summary (within service area):")
print(f"  Mean:  {ndvi_stats['mean']:.3f}")
print(f"  Median: {ndvi_stats['median']:.3f}")
print(f"  Std:   {ndvi_stats['std']:.3f}")
print(f"  Stressed vegetation (NDVI < 0.3): {ndvi_stats['pct_stressed']:.1f}%")
print(f"  Healthy vegetation (NDVI > 0.5):  {ndvi_stats['pct_healthy']:.1f}%")

# ─── 9. Generate maps ───────────────────────────────────────────────────
print("\nGenerating maps...")


def _pixel_edge_extent(da, res=RESOLUTION):
    """Return the imshow extent ``[left, right, bottom, top]`` of *da*'s grid.

    The x/y coordinates of odc-loaded rasters are pixel *centres*, whereas
    ``imshow`` expects the outer *edges* of the image. Padding each side by half
    a pixel keeps the raster registered with the vector AOI overlay.
    """
    half = res / 2
    return [
        float(da.x.min()) - half, float(da.x.max()) + half,
        float(da.y.min()) - half, float(da.y.max()) + half,
    ]


def _style_utm_axes(ax, title):
    """Apply the shared title, axis labels and 'NNNk' metre tick labels."""
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Easting (m)")
    ax.set_ylabel("Northing (m)")
    # A fresh formatter per axis: matplotlib binds each formatter to one axis.
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1000:.0f}k"))
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1000:.0f}k"))


fig, axes = plt.subplots(1, 2, figsize=(18, 8))
fig.suptitle(
    "Sierra Ridge Mutual Water Company — Wildfire Impact Assessment\n"
    f"Sentinel-2 Analysis | Pre: {pre_date.strftime('%Y-%m-%d')} / Post: {post_date.strftime('%Y-%m-%d')}",
    fontsize=14, fontweight="bold", y=0.98
)

# --- Burn severity map ---
ax1 = axes[0]
cmap_sev = ListedColormap(SEVERITY_COLORS)
# Bin edges halfway between the integer class codes 1..N (one bin per colour)
norm_sev = BoundaryNorm(np.arange(0.5, len(SEVERITY_COLORS) + 1.0), cmap_sev.N)

# origin="upper": row 0 is the northern edge (assumes north-up grid, y decreasing)
ax1.imshow(
    severity, cmap=cmap_sev, norm=norm_sev,
    extent=_pixel_edge_extent(dnbr),
    origin="upper"
)
aoi_utm.boundary.plot(ax=ax1, color="black", linewidth=1.5)
_style_utm_axes(ax1, "Burn Severity (dNBR Classification)")

legend_patches = [Patch(facecolor=c, label=n) for n, c in zip(class_names, SEVERITY_COLORS)]
ax1.legend(handles=legend_patches, loc="lower left", fontsize=7, framealpha=0.9)

# --- NDVI map ---
ax2 = axes[1]
vmin, vmax = 0.0, 0.8
im2 = ax2.imshow(
    ndvi_clipped.values, cmap=NDVI_CMAP, vmin=vmin, vmax=vmax,
    extent=_pixel_edge_extent(ndvi_clipped),
    origin="upper"
)
aoi_utm.boundary.plot(ax=ax2, color="black", linewidth=1.5)
_style_utm_axes(ax2, "Vegetation Health (NDVI)")
plt.colorbar(im2, ax=ax2, shrink=0.7, label="NDVI")

plt.tight_layout(rect=[0, 0, 1, 0.94])
map_path = f"{OUTPUT_DIR}/wildfire_assessment_maps.png"
plt.savefig(map_path, dpi=200, bbox_inches="tight")
plt.close()
print(f"  Maps saved: {map_path}")

# ─── 10. NDVI histogram ─────────────────────────────────────────────────
fig2, ax3 = plt.subplots(figsize=(8, 4))
ax3.hist(ndvi_vals, bins=60, color="#2ca02c", edgecolor="white", alpha=0.85)

# Dashed reference lines at the stress / healthy NDVI cut-offs used in the summary
reference_lines = (
    (0.3, "red", "Stress threshold (0.3)"),
    (0.5, "green", "Healthy threshold (0.5)"),
)
for cutoff, line_colour, line_label in reference_lines:
    ax3.axvline(cutoff, color=line_colour, linestyle="--", label=line_label)

ax3.set(
    xlabel="NDVI",
    ylabel="Pixel Count",
    title="Vegetation Health Distribution — Service Area",
)
ax3.legend()
fig2.tight_layout()

hist_path = f"{OUTPUT_DIR}/ndvi_histogram.png"
fig2.savefig(hist_path, dpi=150, bbox_inches="tight")
plt.close(fig2)
print(f"  Histogram saved: {hist_path}")

# ─── 11. Write methods summary / brief ──────────────────────────────────
# Compute burn severity area breakdown for the brief
burn_area_low = class_areas.get("Low Severity", 0)
burn_area_mod_low = class_areas.get("Moderate-Low Severity", 0)
burn_area_mod_high = class_areas.get("Moderate-High Severity", 0)
burn_area_high = class_areas.get("High Severity", 0)
burn_area_total = burn_area_low + burn_area_mod_low + burn_area_mod_high + burn_area_high
regrowth_area = class_areas.get("Enhanced Regrowth (High)", 0) + class_areas.get("Enhanced Regrowth (Low)", 0)

# stats_text is a fixed-width table. It is fenced as a code block because, as
# plain Markdown, its "-----" rule would turn the header row into a setext
# heading and the column whitespace would collapse.
brief = f"""# Wildfire Burn Severity & Vegetation Impact Brief
## Sierra Ridge Mutual Water Company — Service Area Assessment

**Prepared:** 2026-04-12
**Meeting:** Dana Morales, Operations Manager — 4:30 PM Pacific

---

## Key Findings

| Metric | Value |
|--------|-------|
| Total service area analyzed | {total_area_km2:.1f} km2 |
| Area with active burn signatures | {burn_area_total:.2f} km2 ({burn_area_total/total_area_km2*100:.1f}%) |
| Area showing regrowth | {regrowth_area:.2f} km2 ({regrowth_area/total_area_km2*100:.1f}%) |
| Mean NDVI (vegetation health) | {ndvi_stats['mean']:.3f} |
| Vegetation under stress (NDVI < 0.3) | {ndvi_stats['pct_stressed']:.1f}% of area |
| Healthy vegetation (NDVI > 0.5) | {ndvi_stats['pct_healthy']:.1f}% of area |

## Burn Severity Breakdown

```
{stats_text}
```

## Imagery Details

| Parameter | Pre-fire Scene | Post-fire Scene |
|-----------|---------------|-----------------|
| Scene ID | `{pre_item.id}` | `{post_item.id}` |
| Date | {pre_date.strftime('%Y-%m-%d')} | {post_date.strftime('%Y-%m-%d')} |
| Source | Sentinel-2 L2A (Copernicus) | Sentinel-2 L2A (Copernicus) |
| Resolution | {RESOLUTION}m | {RESOLUTION}m |

## Methods Summary

This analysis uses **Sentinel-2 Level-2A** surface reflectance imagery accessed from the
Microsoft Planetary Computer STAC catalog. The workflow:

1. **Scene selection** — Low-cloud (<{CLOUD_COVER_MAX}% scene cloud cover) Sentinel-2 scenes were selected for
   a post-fire period (summer/fall 2025) and a pre-fire baseline (same season, prior year).

2. **Cloud masking** — The Scene Classification Layer (SCL) was used to mask clouds, cloud shadows,
   and cirrus pixels.

3. **Burn severity (dNBR)** — The Normalized Burn Ratio (NBR) was calculated from NIR (B08) and
   SWIR (B12) bands for both dates. The differenced NBR (dNBR = NBR_pre - NBR_post) was classified
   into USGS burn severity categories (Key & Benson, 2006).

4. **Vegetation health (NDVI)** — The Normalized Difference Vegetation Index was calculated from
   NIR (B08) and Red (B04) bands of the post-fire scene to assess current vegetation condition.

5. **Spatial clipping** — All rasters were clipped to the Sierra Ridge service area polygon and
   reprojected to UTM Zone 10N (EPSG:32610) for accurate area calculations.

6. **Area statistics** — Pixel counts per severity class were converted to km2 using the
   {RESOLUTION}m pixel resolution.

## Attachments

- `wildfire_assessment_maps.png` — Side-by-side burn severity and NDVI maps
- `ndvi_histogram.png` — Distribution of vegetation health values across the service area
- `service-area.geojson` — Prospect service area boundary

---
*Analysis performed using Sentinel-2 L2A imagery via Microsoft Planetary Computer.
Classification follows USGS dNBR severity thresholds (Key & Benson, 2006).*
"""

brief_path = f"{OUTPUT_DIR}/wildfire_brief.md"
# Explicit UTF-8: the brief contains non-ASCII characters (em dashes) and the
# platform default encoding is not guaranteed to handle them.
with open(brief_path, "w", encoding="utf-8") as f:
    f.write(brief)
print(f"\n  Brief saved: {brief_path}")
print("\nDone. All outputs generated.")
