Torch Utils

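Convenience functions built on PyTorch.

These functions control certain aspects of how PyTorch is used: selecting the device on which operations run, looking up optimizers and learning-rate schedulers by name, and filtering user parameters against a class signature.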

filter_parameters(func, user_params)

Filter user-provided parameters according to the signature of func, keeping only those that match it.

Parameters:

Name         Type  Description                             Default
func         type  Class object whose signature is used.  required
user_params  dict  User-provided parameters.               required

Returns:

Type  Description
dict  The entries of user_params that match func's signature.
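A minimal sketch of this kind of filtering, using the standard library's `inspect.signature`. It is an assumption about the behaviour, not the library's own code: keys that func can accept by keyword are kept and all others are silently dropped. `Model` is a hypothetical class used only for illustration.

```python
import inspect
from typing import Any, Callable, Dict


def filter_parameters(func: Callable[..., Any], user_params: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the entries of user_params that func accepts as keyword arguments."""
    # For a class, inspect.signature reports the constructor parameters (without self).
    signature = inspect.signature(func)
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    accepted = {name for name, param in signature.parameters.items() if param.kind in keyword_kinds}
    return {key: value for key, value in user_params.items() if key in accepted}


class Model:
    """Hypothetical stand-in class used only for this example."""

    def __init__(self, hidden_size: int, dropout: float = 0.1) -> None:
        self.hidden_size = hidden_size
        self.dropout = dropout


params = {"hidden_size": 64, "dropout": 0.2, "batch_size": 32}
model = Model(**filter_parameters(Model, params))  # "batch_size" is dropped before the call
```

Because only named parameters are matched, a class whose constructor also takes `**kwargs` would still lose the unmatched keys in this sketch; a real implementation might choose to pass everything through in that case.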

get_device()

Get the device on which operations take place.

Returns:

Type  Description
str   The device on which operations take place, e.g. "cuda", "cpu" or "mps".
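The selection can be sketched as below, assuming a priority of CUDA, then Apple's MPS backend, then CPU (the order is not stated above). `pick_device` is a hypothetical helper that keeps the decision separate from the hardware probe, so it can be exercised without a GPU; `torch.cuda.is_available()` and `torch.backends.mps.is_available()` are PyTorch calls, the latter available from PyTorch 1.12.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Choose a device string from availability flags (assumed priority: CUDA, MPS, CPU)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device() -> str:
    """Probe the current machine with PyTorch and return a device string."""
    import torch  # imported here so pick_device can be used and tested without torch

    return pick_device(
        cuda_available=torch.cuda.is_available(),
        mps_available=torch.backends.mps.is_available(),
    )
```

The returned string is accepted wherever PyTorch takes a device, for example `torch.zeros(3, device=get_device())`.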

get_optimizer(name)

Return the optimizer class given its name.

Parameters:

Name  Type  Description      Default
name  str   Optimizer name.  required

Returns:

Type       Description
Optimizer  Optimizer class.
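A name-based lookup could be sketched as follows, assuming the name is the class name as exported by `torch.optim` (for example "Adam"). `lookup_class` is a hypothetical helper, and rejecting the abstract `Optimizer` base itself is a design choice of this sketch rather than documented behaviour.

```python
import inspect
from types import ModuleType


def lookup_class(module: ModuleType, name: str, base: type) -> type:
    """Return module.<name> if it is a strict subclass of base, else raise ValueError."""
    candidate = getattr(module, name, None)
    if inspect.isclass(candidate) and issubclass(candidate, base) and candidate is not base:
        return candidate
    raise ValueError(f"{name!r} is not a {base.__name__} subclass in {module.__name__}")


def get_optimizer(name: str) -> type:
    """Resolve an optimizer class such as "Adam" or "SGD" from torch.optim."""
    import torch.optim  # lazy import: lookup_class itself does not depend on torch

    return lookup_class(torch.optim, name, torch.optim.Optimizer)
```

With PyTorch installed, `get_optimizer("Adam")(model.parameters(), lr=1e-3)` would build an Adam optimizer for some `torch.nn.Module` named `model`.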

get_optimizers()

Return all optimizers available in torch.optim as a dictionary.

Returns:

Type  Description
dict  Optimizers available in torch.optim.
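One way to build such a dictionary is to scan the module with `inspect.getmembers`, as sketched below. Mapping class names to classes is an assumption (the return type above is only `dict`), and `collect_subclasses` is a hypothetical helper.

```python
import inspect
from types import ModuleType
from typing import Dict


def collect_subclasses(module: ModuleType, base: type) -> Dict[str, type]:
    """Map each class name in module to its class, keeping strict subclasses of base."""
    return {
        name: obj
        for name, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base) and obj is not base
    }


def get_optimizers() -> Dict[str, type]:
    """Collect every Optimizer subclass exported by torch.optim."""
    import torch.optim  # lazy import: collect_subclasses does not depend on torch

    return collect_subclasses(torch.optim, torch.optim.Optimizer)
```

The keys depend on the installed PyTorch version, since optimizers are added to `torch.optim` over time.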

get_scheduler(name)

Return the scheduler class given its name.

Parameters:

Name  Type  Description      Default
name  str   Scheduler name.  required

Returns:

Type   Description
Union  Scheduler class.
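A sketch of the lookup, assuming PyTorch 2.x, where `torch.optim.lr_scheduler.LRScheduler` is the public base class. In some PyTorch releases `ReduceLROnPlateau` does not inherit from `LRScheduler`, so the sketch accepts either base, which is consistent with the `Union` return type. `lookup_class` and its `accepted`/`abstract` parameters are hypothetical.

```python
import inspect
from types import ModuleType
from typing import Tuple


def lookup_class(
    module: ModuleType,
    name: str,
    accepted: Tuple[type, ...],
    abstract: Tuple[type, ...] = (),
) -> type:
    """Return module.<name> if it subclasses any accepted base and is not an abstract base."""
    candidate = getattr(module, name, None)
    if inspect.isclass(candidate) and issubclass(candidate, accepted) and candidate not in abstract:
        return candidate
    raise ValueError(f"{name!r} is not an accepted class in {module.__name__}")


def get_scheduler(name: str) -> type:
    """Resolve a scheduler class such as "StepLR" from torch.optim.lr_scheduler."""
    from torch.optim import lr_scheduler  # lazy import; assumes PyTorch 2.x

    # ReduceLROnPlateau is accepted explicitly because it does not subclass
    # LRScheduler in every release; the LRScheduler base itself is rejected.
    return lookup_class(
        lr_scheduler,
        name,
        accepted=(lr_scheduler.LRScheduler, lr_scheduler.ReduceLROnPlateau),
        abstract=(lr_scheduler.LRScheduler,),
    )
```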

get_schedulers()

Return all schedulers available in torch.optim.lr_scheduler as a dictionary.

Returns:

Type  Description
dict  Schedulers available in torch.optim.lr_scheduler.
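The dictionary could be assembled with `inspect.getmembers`, as in this sketch, again assuming PyTorch 2.x and a name-to-class mapping. Private names are skipped so that internal classes such as `_LRScheduler` do not appear, and `ReduceLROnPlateau` is accepted explicitly because it does not inherit from `LRScheduler` in every release. `collect_classes` is a hypothetical helper.

```python
import inspect
from types import ModuleType
from typing import Dict, Tuple


def collect_classes(
    module: ModuleType,
    accepted: Tuple[type, ...],
    abstract: Tuple[type, ...] = (),
) -> Dict[str, type]:
    """Map public class names in module to classes that subclass any accepted base."""
    return {
        name: obj
        for name, obj in inspect.getmembers(module, inspect.isclass)
        # Skip private names, abstract bases, and classes outside the accepted hierarchy.
        if not name.startswith("_") and issubclass(obj, accepted) and obj not in abstract
    }


def get_schedulers() -> Dict[str, type]:
    """Collect the scheduler classes exported by torch.optim.lr_scheduler."""
    from torch.optim import lr_scheduler  # lazy import; assumes PyTorch 2.x

    return collect_classes(
        lr_scheduler,
        accepted=(lr_scheduler.LRScheduler, lr_scheduler.ReduceLROnPlateau),
        abstract=(lr_scheduler.LRScheduler,),
    )
```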