register_model
Custom model registration utilities.
clear_custom_models()
get_custom_model(name)
Get the custom model corresponding to name from the registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the model to retrieve. | required |
Returns:
Type | Description |
---|---|
Module | The requested model. |
Raises:
Type | Description |
---|---|
ValueError | If the model is not registered. |
Source code in src/careamics/config/architectures/register_model.py
register_model(name)
Decorator used to register a torch.nn.Module class with a given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the model. | required |
Returns:
Type | Description |
---|---|
Callable | Function that allows instantiating the wrapped Module class. |
Raises:
Type | Description |
---|---|
ValueError | If a model is already registered with that name. |
Examples:
from torch import nn, ones
from careamics.config.architectures.register_model import register_model

# Register LinearModel under the name "linear".
@register_model(name="linear")
class LinearModel(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(ones(in_features, out_features))
        self.bias = nn.Parameter(ones(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias
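
The behavior documented on this page (register a class under a name, look it up by name, raise ValueError on a duplicate or unknown name, clear the registry) can be pictured with a small name-to-class registry. The following is a minimal sketch of that pattern in plain Python, not the library's actual implementation: `_REGISTRY`, `register`, `get`, `clear`, and `ToyModel` are hypothetical names chosen for illustration, and a plain class stands in for a torch.nn.Module so that the snippet runs without torch.

```python
from typing import Callable, Dict

# Hypothetical stand-in for a module-level registry (assumption, not the real one).
_REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Return a decorator that stores the decorated class under `name`."""
    # Reject duplicates up front, mirroring the documented ValueError.
    if name in _REGISTRY:
        raise ValueError(f"A model is already registered as '{name}'.")

    def decorator(cls: type) -> type:
        _REGISTRY[name] = cls
        return cls  # hand the class back unchanged so it can still be used directly

    return decorator


def get(name: str) -> type:
    """Look up a registered class, mirroring the documented ValueError."""
    if name not in _REGISTRY:
        raise ValueError(f"Model '{name}' is not registered.")
    return _REGISTRY[name]


def clear() -> None:
    """Remove every entry from the registry."""
    _REGISTRY.clear()


@register("toy")
class ToyModel:
    """Plain-Python placeholder for a model class."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float) -> float:
        return x * self.scale


model_cls = get("toy")       # retrieve the class by its registered name
model = model_cls(scale=2.0)  # instantiate it like any other class
```

In this sketch the duplicate check runs when the decorator is created, so a name clash is reported before the second class is ever stored. Clearing the registry between tests keeps registrations from leaking from one test to the next.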