ClassBalancedCrossEntropyLoss#

pydantic model vision_architectures.losses.class_balanced_cross_entropy_loss.ClassBalancedCrossEntropyLossConfig[source]#

Bases: CustomBaseModel

JSON schema:
{
   "title": "ClassBalancedCrossEntropyLossConfig",
   "type": "object",
   "properties": {
      "num_classes": {
         "description": "Number of classes to weight cross entropy loss.",
         "title": "Num Classes",
         "type": "integer"
      },
      "ema_decay": {
         "default": 0.99,
         "description": "Exponential moving average decay. By default 0.99 is used which has a half life of ~69 steps",
         "title": "Ema Decay",
         "type": "number"
      }
   },
   "required": [
      "num_classes"
   ]
}

Config:
  • arbitrary_types_allowed: bool = True

  • extra: str = ignore

  • validate_default: bool = True

  • validate_assignment: bool = True

  • validate_return: bool = True

Fields:
  • ema_decay (float)

  • num_classes (int)

Validators:

field num_classes: int [Required]#

Number of classes used to weight the cross-entropy loss.

field ema_decay: float = 0.99#

Exponential moving average decay. By default, 0.99 is used, which has a half-life of ~69 steps.

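The schema above can be exercised directly; a minimal sketch, assuming only the import path shown at the top of this page:

from vision_architectures.losses.class_balanced_cross_entropy_loss import (
    ClassBalancedCrossEntropyLossConfig,
)

# num_classes is required; ema_decay falls back to its default of 0.99
config = ClassBalancedCrossEntropyLossConfig(num_classes=4)
print(config.num_classes, config.ema_decay)  # 4 0.99
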
class vision_architectures.losses.class_balanced_cross_entropy_loss.ClassBalancedCrossEntropyLoss(config={}, **kwargs)[source]#

Bases: Module

__init__(config={}, **kwargs)[source]#

Class-balanced cross-entropy loss with running prevalence estimation.

This loss reweights the standard multi-class cross-entropy by the inverse of the observed class prevalences in the training data. Class prevalences are estimated online via an exponential moving average (EMA) using class counts from the incoming targets.

Notes

  • Targets must be discrete integer class indices in [0, num_classes-1]. Probabilistic/soft labels are not supported.

  • For classes that haven’t been observed yet, their prevalence is treated as NaN and replaced in the weight vector by the mean of observed weights (then clamped within 3 standard deviations to avoid extreme values).

Parameters:
  • config (ClassBalancedCrossEntropyLossConfig) – An instance of the Config class that contains all the configuration parameters. It can also be passed as a dictionary and the instance will be created automatically.

  • **kwargs – Additional keyword arguments for configuration.
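
The configuration can therefore be supplied in several equivalent ways; a minimal construction sketch, assuming only the classes documented on this page:

from vision_architectures.losses.class_balanced_cross_entropy_loss import (
    ClassBalancedCrossEntropyLoss,
    ClassBalancedCrossEntropyLossConfig,
)

# All three constructions are equivalent per the parameter description above
loss_fn = ClassBalancedCrossEntropyLoss(ClassBalancedCrossEntropyLossConfig(num_classes=4))
loss_fn = ClassBalancedCrossEntropyLoss(config={"num_classes": 4, "ema_decay": 0.95})
loss_fn = ClassBalancedCrossEntropyLoss(num_classes=4)  # configuration via **kwargs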

update_class_prevalences(target, ignore_index=-100)#

Update the running class-prevalence estimates from a target tensor.

The method counts class occurrences in the provided target, converts counts to per-batch prevalences, then updates the internal EMA-tracked prevalence vector for each class.

Parameters:
  • target (Tensor) – A tensor of integer class indices with any shape, typically (N,) or (N, …) for segmentation. Values outside [0, num_classes-1] are ignored.

  • ignore_index (int) – A class index to ignore during updates.
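
The exact update code is not reproduced here; the following standalone sketch only illustrates the behaviour described above (count classes, convert counts to per-batch prevalences, EMA-update a running vector). The helper name ema_update_prevalences and its numerics are assumptions, not the library's internals:

import torch

def ema_update_prevalences(prevalences, target, num_classes, ema_decay=0.99, ignore_index=-100):
    """Illustrative sketch only; not the library's actual implementation."""
    valid = (target != ignore_index) & (target >= 0) & (target < num_classes)
    counts = torch.bincount(target[valid], minlength=num_classes).float()
    batch_prev = counts / counts.sum().clamp(min=1)   # per-batch prevalences in [0, 1]
    seen_now = counts > 0
    never_seen = torch.isnan(prevalences)
    # First observation of a class seeds the EMA; already-tracked classes get a standard EMA step
    prevalences = torch.where(never_seen & seen_now, batch_prev, prevalences)
    tracked = ~never_seen
    prevalences[tracked] = ema_decay * prevalences[tracked] + (1 - ema_decay) * batch_prev[tracked]
    return prevalences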

get_class_prevalences(device=device(type='cpu'))#

Return the current vector of class prevalences as a tensor.

For classes that haven’t been observed yet, the corresponding entry will be NaN. This method does not perform any imputation or normalization beyond returning the current EMA state.

Parameters:

device – The device on which to place the returned tensor.

Return type:

Tensor

Returns:

Tensor of shape (num_classes,) with dtype float32 containing per-class prevalence estimates in [0, 1] or NaN for unseen classes.
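
A short usage sketch of the two tracking calls above, using only the signatures documented on this page; per the description, the entry for a class that has never been observed stays NaN:

import torch
from vision_architectures.losses.class_balanced_cross_entropy_loss import ClassBalancedCrossEntropyLoss

loss_fn = ClassBalancedCrossEntropyLoss(num_classes=3)
loss_fn.update_class_prevalences(torch.tensor([0, 0, 1, 1, 1]))  # class 2 is never observed
prev = loss_fn.get_class_prevalences()
print(prev.shape)            # torch.Size([3])
print(torch.isnan(prev[2]))  # expected True for the unseen class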

get_class_weights(device=device(type='cpu'))#

Compute per-class weights as the inverse of prevalences with safeguards.

Steps:
  1. Convert current EMA prevalences to a tensor with NaNs for unseen classes.

  2. Take the inverse to obtain raw weights (higher weight for rarer classes).

  3. Replace NaNs by the mean of observed weights to avoid biasing toward unseen classes, then clamp to mean ± 3·std to prevent extreme values.

  4. Renormalize weights to sum to num_classes (so the average weight is 1).
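
A minimal standalone sketch of these four steps follows; the helper name prevalences_to_weights and the exact numerics are assumptions for illustration, not the library's implementation:

import torch

def prevalences_to_weights(prevalences: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch only; not the library's actual implementation."""
    # step 1 (conversion of the EMA state to a tensor with NaNs) is assumed to have happened already
    num_classes = prevalences.numel()
    weights = 1.0 / prevalences                     # step 2: rarer class -> larger raw weight
    observed = ~torch.isnan(weights)
    mean = weights[observed].mean().item()
    std = weights[observed].std().item()
    weights = torch.where(observed, weights, torch.full_like(weights, mean))  # step 3: impute NaNs with the mean
    weights = weights.clamp(mean - 3 * std, mean + 3 * std)                   # step 3: clamp to mean ± 3·std
    return weights * num_classes / weights.sum()    # step 4: renormalize so the average weight is 1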

Parameters:

device – The device on which to place the returned tensor.

Return type:

Tensor

Returns:

Tensor of shape (num_classes,) containing normalized class weights.

forward(input, target, update_class_prevalences=True, return_class_weights=False, *args, **kwargs)[source]#

Compute class-balanced cross entropy.

If update_class_prevalences is True (the default), this first updates the internal class-prevalence EMA using the provided targets; it then computes cross-entropy with a weight vector derived from the current prevalences.

Parameters:
  • input (Tensor) – Logits of shape (N, C, …) where C == num_classes. Any extra spatial dimensions (e.g., H, W, D) are supported as long as target has the matching non-channel shape expected by torch.nn.functional.cross_entropy.

  • target (Tensor) – Integer class indices with shape matching input without the channel dimension, e.g., (N, …) with values in [0, C-1].

  • update_class_prevalences (bool) – If True, update the internal class prevalence estimates using the provided targets.

  • return_class_weights (bool) – If True, also return the per-class weight tensor used for this call.

  • *args – Additional positional args forwarded to F.cross_entropy.

  • **kwargs – Additional keyword args forwarded to F.cross_entropy (e.g., reduction='mean').

Returns:

If return_class_weights is False, a scalar tensor loss; if True, a tuple (loss, class_weights).

Return type:

Tensor or tuple[Tensor, Tensor]
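
A hedged end-to-end sketch of a training step using only the signatures documented on this page; the shapes and the second inspection call are illustrative placeholders:

import torch
from vision_architectures.losses.class_balanced_cross_entropy_loss import ClassBalancedCrossEntropyLoss

loss_fn = ClassBalancedCrossEntropyLoss(num_classes=4)

logits = torch.randn(2, 4, 32, 32, requires_grad=True)  # (N, C, H, W) logits
target = torch.randint(0, 4, (2, 32, 32))               # (N, H, W) integer class indices

# Updates the prevalence EMA from `target`, then computes the weighted cross-entropy
loss = loss_fn(logits, target)
loss.backward()

# Inspect the per-class weights without counting this batch a second time
loss, class_weights = loss_fn(logits, target, update_class_prevalences=False, return_class_weights=True)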