ARNIQA
Module Interface
- class torchmetrics.image.arniqa.ARNIQA(regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False, **kwargs)
ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with high correlation to human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions, regardless of the image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
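To make the two-scale data flow concrete, here is a toy, standard-library-only sketch. It is not the torchmetrics implementation: the single-channel list-of-lists image, the `half_scale` 2x2 average pooling, the `toy_encoder` (which returns the pixel mean and standard deviation in place of ResNet-50 features), the regressor weights, and the final clamp to [0, 1] are all hypothetical stand-ins chosen for illustration.

```python
from statistics import mean, pstdev

# Hypothetical image type: a single-channel image stored as a list of rows.
Image = list[list[float]]


def half_scale(img: Image) -> Image:
    """Downsample by 2 with 2x2 average pooling (an odd last row/column is dropped)."""
    h = len(img) // 2 * 2
    w = len(img[0]) // 2 * 2
    return [
        [
            (img[r][c] + img[r][c + 1] + img[r + 1][c] + img[r + 1][c + 1]) / 4
            for c in range(0, w, 2)
        ]
        for r in range(0, h, 2)
    ]


def toy_encoder(img: Image) -> list[float]:
    """Hypothetical stand-in for the ResNet-50 encoder: returns [mean, std] of the pixels."""
    pixels = [p for row in img for p in row]
    return [mean(pixels), pstdev(pixels)]


def toy_arniqa_score(img: Image, weights: list[float], bias: float) -> float:
    """Toy two-scale pipeline: encode both scales, concatenate, apply a linear regressor."""
    # The same encoder is applied to the full-scale and the half-scale image,
    # and the two feature vectors are concatenated.
    features = toy_encoder(img) + toy_encoder(half_scale(img))
    # Linear regressor over the 4 concatenated features.
    raw = sum(w * f for w, f in zip(weights, features)) + bias
    # Assumption: clamp into [0, 1]; the real metric's mapping to [0, 1] may differ.
    return min(1.0, max(0.0, raw))


img = [
    [0.0, 0.2, 0.4, 0.6],
    [0.1, 0.3, 0.5, 0.7],
    [0.2, 0.4, 0.6, 0.8],
    [0.3, 0.5, 0.7, 0.9],
]
print(half_scale(img))  # a 2x2 image of 2x2 block averages
print(toy_arniqa_score(img, weights=[0.5, 0.5, 0.5, 0.5], bias=0.0))
```

The point of the sketch is the structure: one encoder is shared by both scales, and a single linear layer maps the concatenated features to one score per image.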
The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True; otherwise it should be normalized with the ImageNet mean and standard deviation.
Note
Using this metric requires you to have the torchvision package installed. Either install as pip install torchmetrics[image] or pip install torchvision.
As input to forward and update the metric accepts the following input:
- img (Tensor): tensor with images of shape (N, 3, H, W)
As output of forward and compute the metric returns the following output:
- arniqa (Tensor): tensor with the ARNIQA score. If reduction is set to none, the output will have shape (N,); otherwise it will be a scalar tensor. Tensor values are in the [0, 1] range, where higher is better.
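When normalize is False, each channel is expected to have already been transformed as (x - mean_c) / std_c, using the same ImageNet statistics as the Normalize call in the examples. The stdlib-only helpers below sketch that arithmetic; `imagenet_normalize` and `imagenet_denormalize` are hypothetical names, and they operate on a single RGB pixel rather than a tensor for simplicity.

```python
# ImageNet statistics, as used in the Normalize(...) examples on this page.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_normalize(rgb):
    """Map one RGB pixel from [0, 1] to ImageNet-normalized values: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def imagenet_denormalize(rgb):
    """Inverse of imagenet_normalize: x * std + mean, back to the [0, 1] range."""
    return tuple(x * s + m for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


white = imagenet_normalize((1.0, 1.0, 1.0))
print(white)  # roughly (2.25, 2.43, 2.64), i.e. outside [0, 1]
print(imagenet_denormalize(white))  # back to approximately (1.0, 1.0, 1.0)
```

Normalized values routinely fall outside [0, 1], which is consistent with the [0, 1] range check applying only when normalize is True.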
- Parameters:
img – the input image, passed to forward and update rather than to the constructor
regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.
reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].
normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.
autocast (bool) – if True, the metric will convert the model to mixed precision before running the forward pass.
kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
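The three reduction options can be illustrated with a small accumulator. This is a hypothetical sketch, not the torchmetrics class: it assumes that the per-image scores from every update call are collected and that the reduction is applied once in compute, and `ToyScoreAccumulator` works on plain Python floats rather than tensors.

```python
class ToyScoreAccumulator:
    """Hypothetical accumulator illustrating reduction='sum' | 'mean' | 'none'.

    Not the torchmetrics implementation: scores are plain floats, and all
    per-image scores are kept in a list until compute() applies the reduction.
    """

    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction
        self.scores: list[float] = []

    def update(self, batch_scores: list[float]) -> None:
        # Collect the per-image scores of one batch.
        self.scores.extend(batch_scores)

    def compute(self):
        if self.reduction == "sum":
            return sum(self.scores)
        if self.reduction == "mean":
            return sum(self.scores) / len(self.scores)
        # reduction == "none": one score per image, like an output of shape (N,)
        return list(self.scores)


for reduction in ("sum", "mean", "none"):
    acc = ToyScoreAccumulator(reduction)
    acc.update([0.25, 0.5])  # first batch of two images
    acc.update([0.75])       # second batch of one image
    print(reduction, acc.compute())  # sum 1.5, mean 0.5, none [0.25, 0.5, 0.75]
```

In this sketch, "mean" averages over all three images (0.5) rather than averaging the two per-batch means (0.5625).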
- Raises:
ModuleNotFoundError – If the torchvision package is not installed
ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"]
ValueError – If reduction is not in ["sum", "mean", "none"]
ValueError – If normalize is not a bool
ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].
ValueError – If the input image values are not in the [0, 1] range when normalize is set to True
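The conditions listed under Raises can be mirrored by two plain-Python helpers. `check_arniqa_args` and `check_arniqa_input` are hypothetical names; the input is described by a shape tuple and a flat list of values instead of a tensor, and this sketch does not say where torchmetrics actually performs each check (constructor versus update).

```python
VALID_DATASETS = ("kadid10k", "koniq10k")
VALID_REDUCTIONS = ("sum", "mean", "none")


def check_arniqa_args(regressor_dataset: str, reduction: str, normalize: object) -> None:
    """Hypothetical mirror of the argument checks listed under Raises."""
    if regressor_dataset not in VALID_DATASETS:
        raise ValueError(f"regressor_dataset must be one of {VALID_DATASETS}, got {regressor_dataset!r}")
    if reduction not in VALID_REDUCTIONS:
        raise ValueError(f"reduction must be one of {VALID_REDUCTIONS}, got {reduction!r}")
    if not isinstance(normalize, bool):
        raise ValueError(f"normalize must be a bool, got {type(normalize).__name__}")


def check_arniqa_input(shape: tuple, values: list, normalize: bool) -> None:
    """Hypothetical mirror of the input checks: shape (N, 3, H, W) and, if normalize, values in [0, 1]."""
    if len(shape) != 4 or shape[1] != 3:
        raise ValueError(f"expected an image tensor of shape (N, 3, H, W), got {shape}")
    if normalize and any(v < 0.0 or v > 1.0 for v in values):
        raise ValueError("input values must be in the [0, 1] range when normalize=True")


check_arniqa_args("koniq10k", "mean", True)  # passes silently
check_arniqa_input((8, 3, 224, 224), [0.0, 0.5, 1.0], normalize=True)
try:
    check_arniqa_args("koniq10k", "mean", 1)  # 1 is an int, not a bool
except ValueError as err:
    print(err)
```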
Examples
>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> img = rand(8, 3, 224, 224)
>>> # Non-normalized input
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=True)
>>> metric(img)
tensor(0.5308)
>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Normalized input
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=False)
>>> metric(img)
tensor(0.5065)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> metric.update(torch.rand(8, 3, 224, 224))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(8, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.arniqa(img, regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False)
ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with high correlation to human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions, regardless of the image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True; otherwise it should be normalized with the ImageNet mean and standard deviation.
Note
Using this metric requires you to have the torchvision package installed. Either install as pip install torchmetrics[image] or pip install torchvision.
- Parameters:
img (Tensor) – tensor with images of shape (N, 3, H, W)
regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.
reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].
normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.
autocast (bool) – boolean indicating whether to use automatic mixed precision
- Return type:
- Returns:
A tensor in the [0, 1] range, where higher is better, representing the ARNIQA score of the input image. If reduction is set to none, the output will have shape (N,); otherwise it will be a scalar tensor.
- Raises:
ModuleNotFoundError – If the torchvision package is not installed
ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"]
ValueError – If reduction is not in ["sum", "mean", "none"]
ValueError – If normalize is not a bool
ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].
ValueError – If the input image values are not in the [0, 1] range when normalize is set to True
Examples
>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> img = rand(8, 3, 224, 224)
>>> # Non-normalized input
>>> arniqa(img, regressor_dataset='koniq10k', normalize=True)
tensor(0.5308)
>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Normalized input
>>> arniqa(img, regressor_dataset='koniq10k', normalize=False)
tensor(0.5065)