Spectral Distortion Index
Module Interface
- class torchmetrics.image.SpectralDistortionIndex(p=1, reduction='elementwise_mean', **kwargs)
Compute Spectral Distortion Index (SpectralDistortionIndex), also known as D_lambda.
The metric is used to compare the spectral distortion between two images.
As input to forward and update the metric accepts the following input
- preds (Tensor): Low resolution multispectral image of shape (N,C,H,W)
- target (Tensor): High resolution fused image of shape (N,C,H,W)
As output of forward and compute the metric returns the following output
- sdi (Tensor): if reduction != 'none', returns a float scalar tensor with the average SDI value over samples; otherwise returns a tensor of shape (N,) with SDI values per sample
- Parameters:
reduction (Literal['elementwise_mean', 'sum', 'none']) – a method to reduce metric score over labels.
- 'elementwise_mean': takes the mean (default)
- 'sum': takes the sum
- 'none': no reduction will be applied
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.spectral_distortion_index(preds, target, p=1, reduction='elementwise_mean')
Calculate Spectral Distortion Index (SpectralDistortionIndex) also known as D_lambda.
Metric is used to compare the spectral distortion between two images.
- Parameters:
- Return type:
- Returns:
Tensor with SpectralDistortionIndex score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
ValueError – If p is not a positive integer.
Example
>>> from torch import rand
>>> from torchmetrics.functional.image import spectral_distortion_index
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> spectral_distortion_index(preds, target)
tensor(0.0234)