Remote sensing with satellite images

In this example, we demonstrate remote sensing inverse problems for multispectral satellite imaging. We focus on pan-sharpening, i.e., recovering high-resolution multispectral images from pairs of measurements (a low-resolution multispectral image and a high-resolution panchromatic, single-band image), using the forward operator deepinv.physics.Pansharpen.

Such problems have important applications in environmental monitoring, urban planning, disaster recovery, etc.
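As a rough sketch of what a pan-sharpening forward operator does, the snippet below builds a toy model in plain PyTorch. It assumes a box blur plus decimation (a single average-pooling step) for the multispectral branch and a flat spectral response function (the band average) for the panchromatic branch. deepinv.physics.Pansharpen uses its own anti-aliasing filter and noise model, so this is illustrative only; the function name toy_pansharpen_forward is hypothetical.

```python
import torch
import torch.nn.functional as F


def toy_pansharpen_forward(x: torch.Tensor, factor: int = 4):
    """Hypothetical toy pan-sharpening forward model (not the deepinv operator).

    x: high-resolution multispectral image of shape (B, C, H, W).
    Returns [ms, pan]:
      ms  -- low-resolution multispectral image, shape (B, C, H/factor, W/factor)
      pan -- high-resolution panchromatic image, shape (B, 1, H, W)
    """
    # Multispectral branch: box blur + decimation in one average-pooling step.
    ms = F.avg_pool2d(x, kernel_size=factor, stride=factor)
    # Panchromatic branch: flat spectral response, i.e. the average over bands.
    pan = x.mean(dim=1, keepdim=True)
    return [ms, pan]


x = torch.rand(1, 4, 64, 64)  # random stand-in for a 4-band image
ms, pan = toy_pansharpen_forward(x, factor=4)
print(ms.shape, pan.shape)  # torch.Size([1, 4, 16, 16]) torch.Size([1, 1, 64, 64])
```

Returning a list of two measurements with different shapes mirrors the (MS, PAN) pairs used throughout this example.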

We provide a convenient satellite image dataset for pan-sharpening, deepinv.datasets.NBUDataset, from the paper A Large-Scale Benchmark Data Set for Evaluating Pansharpening Performance. It includes data from several satellites, such as the WorldView satellites.

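For remote sensing experiments, DeepInverse provides components including the following, all of which are used in this example:

Physics: deepinv.physics.Pansharpen
Datasets: deepinv.datasets.NBUDataset
Models: deepinv.models.PanNet
Metrics: deepinv.metric.QNR, deepinv.metric.distortion.SpectralAngleMapper, deepinv.metric.distortion.ERGAS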

import deepinv as dinv
import torch

Load raw pan-sharpening measurements

The dataset provides raw pan-sharpening measurements as pairs (MS, PAN), where MS is the low-resolution 4-band multispectral image and PAN is the high-resolution panchromatic image. Note that there are no ground-truth images!

Note

The pan-sharpening measurements are provided as a deepinv.utils.TensorList, since the pan-sharpening physics deepinv.physics.Pansharpen is a stacked physics combining deepinv.physics.Downsampling and deepinv.physics.Decolorize. See the User Guide Combining Physics for more information.

For plotting purposes, we only show the first 3 bands (RGB).

Note also that the linear adjoint depends on the spectral response function (SRF), which is unknown for real data, so computing the adjoint requires assuming an SRF.
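To make the stacking concrete: if the forward operator returns the pair (A1 x, A2 x), its adjoint maps a pair (y1, y2) to A1^T y1 + A2^T y2, which is why the adjoint depends on the SRF through A2. The sketch below checks this with the standard dot-product test, <A x, y> = <x, A^T y>, on toy operators (average-pooling decimation for A1 and a flat SRF for A2). The functions forward and adjoint are hypothetical and not part of DeepInverse.

```python
import torch
import torch.nn.functional as F

FACTOR = 4


def forward(x):
    """Hypothetical stacked operator x -> [A1 x, A2 x]."""
    ms = F.avg_pool2d(x, FACTOR)       # A1: box blur + decimation
    pan = x.mean(dim=1, keepdim=True)  # A2: flat SRF (band average)
    return [ms, pan]


def adjoint(y, n_bands):
    """Adjoint of the stacked operator: A1^T y1 + A2^T y2."""
    ms, pan = y
    # Adjoint of non-overlapping average pooling: copy each value into its
    # FACTOR x FACTOR block and divide by the block size.
    a1t = F.interpolate(ms, scale_factor=FACTOR, mode="nearest") / FACTOR**2
    # Adjoint of the band average: replicate the PAN image into every band, / n_bands.
    a2t = pan.expand(-1, n_bands, -1, -1) / n_bands
    return a1t + a2t


torch.manual_seed(0)
x = torch.randn(2, 4, 32, 32, dtype=torch.float64)
y = [
    torch.randn(2, 4, 8, 8, dtype=torch.float64),
    torch.randn(2, 1, 32, 32, dtype=torch.float64),
]

# Dot-product test: <A x, y> (summed over both measurements) == <x, A^T y>
lhs = sum((ax * yi).sum() for ax, yi in zip(forward(x), y))
rhs = (x * adjoint(y, n_bands=4)).sum()
print(torch.allclose(lhs, rhs))  # True
```

Swapping the flat SRF for different band weights would change a2t, and hence the adjoint, which is the dependence the note above refers to.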

DATA_DIR = dinv.utils.get_data_home()
dataset = dinv.datasets.NBUDataset(DATA_DIR, return_pan=True, download=True)

y = dataset[0].unsqueeze(0)  # MS (1,4,256,256), PAN (1,1,1024,1024)

physics = dinv.physics.Pansharpen((4, 1024, 1024), factor=4)

# Pansharpen with classical Brovey method
x_hat = physics.A_dagger(y)  # shape (1,4,1024,1024)

dinv.utils.plot(
    [
        y[0][:, :3],
        y[1],  # Note this will be interpolated to match high-res image size
        x_hat[:, :3],
        physics.A_adjoint(y)[:, :3],
    ],
    titles=[
        "Input MS",
        "Input PAN",
        "Pseudo-inverse using Brovey method",
        "Linear adjoint",
    ],
    dpi=1200,
)

# Evaluate performance - note we can only use QNR as we have no GT
qnr = dinv.metric.QNR()
print(qnr(x_net=x_hat, x=None, y=y, physics=physics))
[Figure: Input MS | Input PAN | Pseudo-inverse using Brovey method | Linear adjoint]
Downloading datasets/nbu/gaofen-1.zip
Dataset has been successfully downloaded.
tensor([0.4133])
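The value printed above is the QNR (Quality with No Reference) index. For reference, its standard definition in the pansharpening literature combines a spectral distortion index D_λ and a spatial distortion index D_s, both built from the universal image quality index Q computed between image bands; QNR equals 1 for an ideal result. Here L is the number of bands, \hat{x}_l is band l of the fused image, \tilde{x}_l is band l of the low-resolution MS input, P is the PAN image and \tilde{P} is a degraded (low-resolution) PAN image. The exponents α, β, p, q are commonly set to 1; the defaults in deepinv.metric.QNR may differ.

```latex
\mathrm{QNR} = (1 - D_\lambda)^{\alpha}\,(1 - D_s)^{\beta},
\qquad
D_\lambda = \left( \frac{1}{L(L-1)} \sum_{l=1}^{L} \sum_{\substack{r=1 \\ r \neq l}}^{L}
  \bigl| Q(\hat{x}_l, \hat{x}_r) - Q(\tilde{x}_l, \tilde{x}_r) \bigr|^{p} \right)^{1/p},
\qquad
D_s = \left( \frac{1}{L} \sum_{l=1}^{L}
  \bigl| Q(\hat{x}_l, P) - Q(\tilde{x}_l, \tilde{P}) \bigr|^{q} \right)^{1/q}
```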

Simulate pan-sharpening measurements

We can also simulate pan-sharpening measurements, so that we have pairs of measurements and ground truth. This time, the dataset loads only the multispectral images, which we use as ground-truth images x. For the pan-sharpening physics, we assume a flat spectral response function, although the SRF could also be learned jointly. We simulate Gaussian noise on the panchromatic images.
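The remark that the SRF could be learned jointly can be illustrated with a small module: the SRF is parameterised as softmax-normalised band weights (so they stay positive and sum to one), initialised to the flat SRF, and Gaussian noise is added to the simulated PAN image only. This is a minimal sketch under those assumptions; the class name LearnableSRF and the noise level sigma=0.05 are hypothetical and not DeepInverse API.

```python
import torch
import torch.nn as nn


class LearnableSRF(nn.Module):
    """Hypothetical learnable spectral response function (not part of DeepInverse)."""

    def __init__(self, n_bands: int):
        super().__init__()
        # Zero logits -> softmax gives equal weights, i.e. the flat SRF.
        self.logits = nn.Parameter(torch.zeros(n_bands))

    def weights(self) -> torch.Tensor:
        # Softmax keeps the band weights positive and summing to one.
        return torch.softmax(self.logits, dim=0)

    def forward(self, x: torch.Tensor, sigma: float = 0.0) -> torch.Tensor:
        # Weighted sum over the band dimension: (B, C, H, W) -> (B, 1, H, W).
        pan = torch.einsum("bchw,c->bhw", x, self.weights()).unsqueeze(1)
        # Additive Gaussian noise on the panchromatic image only.
        return pan + sigma * torch.randn_like(pan)


srf = LearnableSRF(n_bands=4)
x = torch.rand(1, 4, 32, 32)
pan_clean = srf(x)              # equals the band average at initialisation
pan_noisy = srf(x, sigma=0.05)  # noisy PAN measurement
```

In a joint-learning setup, srf.parameters() would simply be optimised together with the reconstruction network's parameters.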

dataset = dinv.datasets.NBUDataset(DATA_DIR, return_pan=False)

x = dataset[0].unsqueeze(0)  # just MS of shape 1,4,256,256

physics = dinv.physics.Pansharpen((4, 256, 256), factor=4, srf="flat")

y = physics(x)

# Pansharpen with classical Brovey method
x_hat = physics.A_dagger(y)
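The classical Brovey transform used as a baseline above can be written in a few lines: upsample each MS band to the PAN resolution, then rescale it by the ratio between the PAN image and the intensity of the upsampled MS image. The sketch below is the textbook version, assuming a flat SRF and bilinear upsampling; the filters, SRF and numerical safeguards in deepinv.physics.Pansharpen.A_dagger may differ, and the function name brovey is hypothetical.

```python
import torch
import torch.nn.functional as F


def brovey(ms: torch.Tensor, pan: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Textbook Brovey pan-sharpening (hypothetical helper, flat SRF assumed).

    ms:  low-resolution multispectral image, shape (B, C, h, w)
    pan: high-resolution panchromatic image, shape (B, 1, H, W)
    """
    # Upsample every MS band to the PAN grid.
    ms_up = F.interpolate(ms, size=pan.shape[-2:], mode="bilinear", align_corners=False)
    # Intensity of the upsampled MS image under a flat SRF (band average).
    intensity = ms_up.mean(dim=1, keepdim=True)
    # Inject PAN detail multiplicatively; eps avoids division by zero.
    return ms_up * pan / (intensity + eps)


ms = torch.rand(1, 4, 16, 16) + 0.1
pan = torch.rand(1, 1, 64, 64) + 0.1
x_hat = brovey(ms, pan)
print(x_hat.shape)  # torch.Size([1, 4, 64, 64])
```

With a flat SRF, the band average of the Brovey estimate reproduces the PAN image (up to eps), so the estimate is consistent with the panchromatic measurement by construction.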

Solving pan-sharpening with neural networks

The pan-sharpening physics is compatible with the rest of the DeepInverse library, so we can solve the inverse problem using any method provided in the library. Here, for example, we use the PanNet model (deepinv.models.PanNet).

This model can be trained with supervised losses such as deepinv.loss.SupLoss, or with self-supervised losses such as Equivariant Imaging (deepinv.loss.EILoss), which was applied to pan-sharpening in Wang et al., Perspective-Equivariant Imaging: an Unsupervised Framework for Multispectral Pansharpening.
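The idea behind Equivariant Imaging can be sketched for a generic single-output operator A and reconstruction network f: besides measurement consistency, ||A f(y) - y||^2, the loss asks that a transformed estimate T f(y) be recovered from its own simulated measurements, ||f(A T f(y)) - T f(y)||^2. The sketch below assumes random 90° rotations as the transformations and uses hypothetical names (ei_losses, a toy decimation operator A and a fixed upsampling "network" f); deepinv.loss.EILoss and the perspective transforms of Wang et al. are more general.

```python
import torch
import torch.nn.functional as F


def ei_losses(f, A, y):
    """Hypothetical measurement-consistency and EI losses for a single-output operator A."""
    x1 = f(y)                               # estimate from the real measurement
    mc = F.mse_loss(A(x1), y)               # measurement consistency
    k = int(torch.randint(1, 4, (1,)))      # random rotation by k * 90 degrees
    x2 = torch.rot90(x1, k, dims=(-2, -1))  # transformed estimate T f(y)
    x3 = f(A(x2))                           # re-estimate from its simulated measurement
    ei = F.mse_loss(x3, x2)                 # equivariance term
    return mc, ei


def A(x):
    # Toy forward operator: 4x average-pooling decimation.
    return F.avg_pool2d(x, 4)


def f(y):
    # Toy fixed "network": nearest-neighbour upsampling (nothing is learned here).
    return F.interpolate(y, scale_factor=4, mode="nearest")


y = torch.rand(1, 4, 16, 16)
mc, ei = ei_losses(f, A, y)
print(mc.item(), ei.item())
```

This toy pair is exactly consistent and rotation-equivariant, so both terms are numerically zero; a trained network would instead be penalised for any deviation.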

For evaluation, we use the standard full-reference metrics ERGAS and SAM, and the no-reference metric QNR.
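For reference, the two full-reference metrics can be written directly in PyTorch. The sketch below follows common textbook definitions: SAM is the mean angle (in radians) between the spectral vectors of the estimate and the reference at each pixel, and ERGAS = (100 / factor) * sqrt(mean over bands of (RMSE_b / mu_b)^2), where mu_b is the mean of reference band b. Library implementations, including those behind deepinv.metric.distortion, may use different conventions (units, resolution-ratio factor, batch reduction), so values need not match exactly. The function names are hypothetical.

```python
import torch


def sam(x_hat: torch.Tensor, x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Mean spectral angle in radians per image; inputs of shape (B, C, H, W)."""
    dot = (x_hat * x).sum(dim=1)
    norms = x_hat.norm(dim=1) * x.norm(dim=1)
    # Clamp guards acos against values slightly outside [-1, 1] due to rounding.
    cos = (dot / norms.clamp_min(eps)).clamp(-1.0, 1.0)
    return torch.acos(cos).mean(dim=(-2, -1))


def ergas(x_hat: torch.Tensor, x: torch.Tensor, factor: int = 4) -> torch.Tensor:
    """ERGAS per image with resolution ratio 1/factor; inputs of shape (B, C, H, W)."""
    rmse = ((x_hat - x) ** 2).mean(dim=(-2, -1)).sqrt()  # (B, C)
    mu = x.mean(dim=(-2, -1))                             # (B, C)
    return (100.0 / factor) * ((rmse / mu) ** 2).mean(dim=1).sqrt()


x = torch.rand(2, 4, 32, 32, dtype=torch.float64) + 0.1
print(sam(2.0 * x, x))  # scaling every pixel leaves all spectral angles at ~0
print(ergas(x, x))      # a perfect estimate gives 0
```

The scaling check highlights the design difference between the two: SAM only measures spectral (angular) distortion, whereas ERGAS also penalises radiometric errors.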

Note

This is a tiny example using only 5 images. For speed, we start from pretrained weights and train for a single epoch; to train from scratch, use 50 epochs.
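To give an idea of the kind of architecture involved, here is a heavily simplified, hypothetical network in the spirit of PanNet: it feeds high-pass-filtered versions of the upsampled MS and PAN images to a small CNN and adds the predicted detail to the upsampled MS image (a residual "spectral mapping"). It is not deepinv.models.PanNet, whose layers, filters and interface (model(y, physics)) differ; the class name TinyPanNetLike, the box high-pass filter and all layer sizes are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def high_pass(img: torch.Tensor, k: int = 5) -> torch.Tensor:
    # Subtract a local mean (box low-pass) to keep only high-frequency detail.
    return img - F.avg_pool2d(img, k, stride=1, padding=k // 2, count_include_pad=False)


class TinyPanNetLike(nn.Module):
    """Hypothetical, simplified PanNet-style network (not the deepinv model)."""

    def __init__(self, n_bands: int = 4, width: int = 32, factor: int = 4):
        super().__init__()
        self.factor = factor
        self.body = nn.Sequential(
            nn.Conv2d(n_bands + 1, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, n_bands, 3, padding=1),
        )

    def forward(self, ms: torch.Tensor, pan: torch.Tensor) -> torch.Tensor:
        ms_up = F.interpolate(ms, scale_factor=self.factor, mode="bilinear", align_corners=False)
        # The CNN only sees high-pass inputs and predicts the missing detail...
        detail = self.body(torch.cat([high_pass(ms_up), high_pass(pan)], dim=1))
        # ...which is added to the upsampled MS image to preserve its spectra.
        return ms_up + detail


model = TinyPanNetLike()
out = model(torch.rand(2, 4, 16, 16), torch.rand(2, 1, 64, 64))
print(out.shape)  # torch.Size([2, 4, 64, 64])
```

The residual connection means that a network predicting zero detail falls back to plain bilinear upsampling of the MS image.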

model = dinv.models.PanNet(hrms_shape=(4, 256, 256))
x_net = model(y, physics)  # example forward pass with the untrained model

# Example training loss using measurement consistency on the multispectral images
# and Stein's Unbiased Risk Estimate on the panchromatic images.
loss = dinv.loss.StackedPhysicsLoss(
    [dinv.loss.MCLoss(), dinv.loss.SureGaussianLoss(0.05)]
)

# Evaluate the classical estimate; full-reference metrics are possible since ground truth is available
sam = dinv.metric.distortion.SpectralAngleMapper()
ergas = dinv.metric.distortion.ERGAS(factor=4)
qnr = dinv.metric.QNR()
print(sam(x_hat, x), ergas(x_hat, x), qnr(x_hat, x=None, y=y, physics=physics))

# Create the optimizer, then load the pretrained model and optimizer states
optimizer = torch.optim.Adam(model.parameters())

from deepinv.models.utils import get_weights_url

file_name = "demo_nbu_pansharpen.pth"
url = get_weights_url(model_name="demo", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])

# Train using deepinv Trainer
from torch.utils.data import DataLoader

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    optimizer=optimizer,
    losses=loss,
    metrics=[sam, ergas],
    train_dataloader=DataLoader(dataset),
    epochs=1,
    online_measurements=True,
    plot_images=False,
    compare_no_learning=True,
    no_learning_method="A_dagger",
    show_progress_bar=False,
)

trainer.train()
trainer.test(DataLoader(dataset))

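# Plot results, recomputing the PanNet estimate with the trained weights
model.eval()
with torch.no_grad():
    x_net = model(y, physics)

dinv.utils.plot(
    [
        x[:, :3],
        y[0][:, :3],
        y[1],
        x_hat[:, :3],
        x_net[:, :3],
    ],
    titles=["x HRMS", "y LRMS", "y PAN", "Estimate (classical)", "Estimate (PanNet)"],
)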
[Figure: x HRMS | y LRMS | y PAN | Estimate (classical) | Estimate (PanNet)]
tensor([0.0545]) tensor([4.1428]) tensor([0.4736])
Downloading: "https://huggingface.co/deepinv/demo/resolve/main/demo_nbu_pansharpen.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/demo_nbu_pansharpen.pth
The model has 77124 trainable parameters
Train epoch 0: TotalLoss=0.002, SpectralAngleMapper=0.28, ERGAS=23.228
Eval epoch 0: SpectralAngleMapper=0.275, SpectralAngleMapper no learning=0.039, ERGAS=24.077, ERGAS no learning=4.343
Test results:
SpectralAngleMapper no learning: 0.039 +- 0.011
SpectralAngleMapper: 0.275 +- 0.009
ERGAS no learning: 4.343 +- 1.165
ERGAS: 24.077 +- 9.450

Total running time of the script: (0 minutes 8.299 seconds)
