homura.modules.functional package

Submodules

homura.modules.functional.attention module

homura.modules.functional.attention.kv_attention(query, key, value, mask=None, additive_mask=None, training=True, dropout_prob=0, scaling=True)

Attention using queries, keys and values.

Parameters
  • query (torch.Tensor) – …JxM

  • key (torch.Tensor) – …KxM

  • value (torch.Tensor) – …KxM

  • mask (Optional[torch.Tensor]) – …JxK

  • additive_mask (Optional[torch.Tensor]) –

  • training (bool) –

  • dropout_prob (float) –

  • scaling (bool) –

Returns

torch.Tensor of shape …JxM

Return type

tuple[torch.Tensor, torch.Tensor]
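
The following is a minimal sketch of what kv_attention computes, written from the shapes above rather than from homura's source. It assumes that scaling divides the logits by sqrt(M), that positions where mask is 0 are excluded, that additive_mask is added to the logits before the softmax, and that the returned tuple is (output, attention weights). kv_attention_sketch is a hypothetical name.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def kv_attention_sketch(query: torch.Tensor,
                        key: torch.Tensor,
                        value: torch.Tensor,
                        mask: Optional[torch.Tensor] = None,
                        additive_mask: Optional[torch.Tensor] = None,
                        training: bool = True,
                        dropout_prob: float = 0.0,
                        scaling: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical re-implementation, not homura's code.
    # logits[..., j, k] = <query[..., j, :], key[..., k, :]>, shape ...JxK
    logits = torch.einsum("...jm,...km->...jk", query, key)
    if scaling:
        # Assumption: scaled dot-product attention, i.e. divide by sqrt(M).
        logits = logits / math.sqrt(query.size(-1))
    if mask is not None:
        # Assumption: entries where mask == 0 get zero attention weight.
        logits = logits.masked_fill(mask == 0, float("-inf"))
    if additive_mask is not None:
        # Assumption: additive_mask is added to the logits before softmax.
        logits = logits + additive_mask
    weights = logits.softmax(dim=-1)
    # Dropout on the attention weights, active only when training=True.
    weights = F.dropout(weights, p=dropout_prob, training=training)
    # Weighted sum of values: ...JxK @ ...KxM -> ...JxM
    output = torch.einsum("...jk,...km->...jm", weights, value)
    return output, weights


torch.manual_seed(0)
q = torch.randn(2, 4, 8)  # batch 2, J=4 queries, M=8
k = torch.randn(2, 6, 8)  # K=6 keys
v = torch.randn(2, 6, 8)
out, attn = kv_attention_sketch(q, k, v, training=False)
print(out.shape, attn.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 6])
```

Each row of the weights sums to 1 over the K keys, so the output for each query is a convex combination of the value vectors.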

homura.modules.functional.discretizations module

homura.modules.functional.discretizations.gumbel_sigmoid(input, temp)

Gumbel sigmoid function.

Parameters
  • input (torch.Tensor) –

  • temp (float) –

Return type

torch.Tensor
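
A common formulation of the Gumbel sigmoid (also known as the binary Concrete relaxation) adds logistic noise, which has the same distribution as the difference of two independent Gumbel(0, 1) samples, to the input and then applies a sigmoid divided by the temperature. The sketch below assumes homura follows this formulation; gumbel_sigmoid_sketch is a hypothetical name.

```python
import torch


def gumbel_sigmoid_sketch(input: torch.Tensor, temp: float) -> torch.Tensor:
    # Hypothetical re-implementation, not homura's code.
    # Uniform samples, clamped away from 0 and 1 so the logs stay finite.
    u = torch.rand_like(input).clamp(1e-6, 1 - 1e-6)
    # log(u) - log(1 - u) is Logistic(0, 1) noise, distributed like the
    # difference of two independent Gumbel(0, 1) samples.
    noise = torch.log(u) - torch.log1p(-u)
    # Tempered sigmoid: a small temp pushes outputs toward 0 or 1.
    return torch.sigmoid((input + noise) / temp)


torch.manual_seed(0)
logits = torch.zeros(5)
soft = gumbel_sigmoid_sketch(logits, temp=1.0)    # spread over (0, 1)
sharp = gumbel_sigmoid_sketch(logits, temp=0.01)  # mostly near 0 or 1
```

Lowering temp makes samples closer to hard 0/1 values at the cost of higher-variance gradients; raising it gives smoother outputs.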

homura.modules.functional.discretizations.semantic_hashing(input, is_training)

Semantic hashing

>>> semantic_hashing(torch.randn(3, 3), True)  # stochastic in training (probability 0.5)
tensor([[0.3515, 0.0918, 0.7717],
        [0.8246, 0.1620, 0.0689],
        [1.0000, 0.3575, 0.6598]])
>>> semantic_hashing(torch.randn(3, 3), False)
tensor([[0., 0., 1.],
        [0., 1., 1.],
        [0., 1., 1.]])

Parameters
  • input (torch.Tensor) –

  • is_training (bool) –

Return type

torch.Tensor
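
A common form of semantic hashing uses a saturating sigmoid, max(0, min(1, 1.2 * sigmoid(x) - 0.1)), and during training switches at random between that soft value and a hard bit whose gradient is taken from the soft value. The sketch below assumes homura follows this scheme with probability 0.5 and a hard threshold of 0, and it omits any noise added to the input; it is not homura's code, and semantic_hashing_sketch is a hypothetical name.

```python
import random

import torch


def saturating_sigmoid(x: torch.Tensor) -> torch.Tensor:
    # max(0, min(1, 1.2 * sigmoid(x) - 0.1)): reaches exactly 0 and 1
    return (1.2 * torch.sigmoid(x) - 0.1).clamp(0.0, 1.0)


def semantic_hashing_sketch(input: torch.Tensor, is_training: bool) -> torch.Tensor:
    # Hypothetical re-implementation, not homura's code.
    bits = (input > 0).float()  # assumption: hard threshold at 0
    if not is_training:
        # Evaluation: deterministic binary code.
        return bits
    soft = saturating_sigmoid(input)
    # Hard bits in the forward pass, gradient of `soft` in the backward pass.
    hard = soft + (bits - soft).detach()
    # Assumption: soft or hard code, each with probability 0.5.
    return soft if random.random() < 0.5 else hard


torch.manual_seed(0)
random.seed(0)
codes = semantic_hashing_sketch(torch.randn(3, 3), is_training=False)
```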

homura.modules.functional.discretizations.straight_through_estimator(input)

Straight-through estimator.

>>> straight_through_estimator(torch.randn(3, 3))
tensor([[0., 1., 0.],
        [0., 1., 1.],
        [0., 0., 1.]])

Parameters
  • input (torch.Tensor) –

Return type

torch.Tensor
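
A straight-through estimator binarizes its input in the forward pass and passes the incoming gradient through unchanged in the backward pass. The sketch below implements this with a custom torch.autograd.Function so the forward output is exactly 0 or 1; it assumes a threshold of 0, and _BinarizeSTE and straight_through_estimator_sketch are hypothetical names.

```python
import torch


class _BinarizeSTE(torch.autograd.Function):
    # Hypothetical re-implementation, not homura's code.
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # Forward: hard 0/1 values (assumption: threshold at 0).
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Backward: treat the binarization as the identity.
        return grad_output


def straight_through_estimator_sketch(input: torch.Tensor) -> torch.Tensor:
    return _BinarizeSTE.apply(input)


torch.manual_seed(0)
x = torch.randn(3, 3, requires_grad=True)
y = straight_through_estimator_sketch(x)  # entries are 0. or 1.
y.sum().backward()                        # x.grad is all ones
```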

homura.modules.functional.grad_approximation module

homura.modules.functional.grad_approximation.custom_straight_through_estimator(input_forward, input_backward)
Parameters
  • input_forward (torch.Tensor) –

  • input_backward (torch.Tensor) –

Return type

torch.Tensor
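
Judging by its parameter names, custom_straight_through_estimator presumably returns the value of input_forward while routing gradients to input_backward. The sketch below implements that reading with a custom autograd Function; it assumes both inputs have the same shape and that input_forward receives no gradient. _CustomSTE and custom_ste_sketch are hypothetical names.

```python
import torch


class _CustomSTE(torch.autograd.Function):
    # Hypothetical re-implementation, not homura's code.
    @staticmethod
    def forward(ctx, input_forward: torch.Tensor, input_backward: torch.Tensor) -> torch.Tensor:
        # Forward value comes from input_forward only.
        return input_forward.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # No gradient for input_forward; the full gradient goes to input_backward.
        return None, grad_output


def custom_ste_sketch(input_forward: torch.Tensor, input_backward: torch.Tensor) -> torch.Tensor:
    return _CustomSTE.apply(input_forward, input_backward)


torch.manual_seed(0)
logits = torch.randn(4, requires_grad=True)
soft = torch.sigmoid(logits)
hard = (soft > 0.5).float()
out = custom_ste_sketch(hard, soft)  # forward value: hard; gradient: through soft
```

A typical use is to pair a hard decision in the forward pass with a differentiable surrogate, such as a sigmoid, for the backward pass.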

homura.modules.functional.knn module

homura.modules.functional.knn.faiss_knn(keys, queries, num_neighbors, distance)

k-nearest neighbor search using faiss. Use k_nearest_neighbor instead of calling this function directly.

Parameters
  • keys (torch.Tensor) – tensor of (num_keys, dim)

  • queries (torch.Tensor) – tensor of (num_queries, dim)

  • num_neighbors (int) – k

  • distance (str) – name of the distance as a str, or a faiss.METRIC_* constant.

Returns

scores and indices as tensors

Return type

tuple[torch.Tensor, torch.Tensor]

homura.modules.functional.knn.k_nearest_neighbor(keys, queries, num_neighbors, distance, *, backend='torch')

k-nearest neighbor search. The faiss backend requires a GPU; the torch backend is JIT-compatible.

Parameters
  • keys (torch.Tensor) – tensor of (num_keys, dim)

  • queries (torch.Tensor) – tensor of (num_queries, dim)

  • num_neighbors (int) – k

  • distance (str) – name of the distance (inner_product or l2). The faiss backend additionally supports l1, linf and jansen_shannon.

  • backend (str) – backend to use, faiss or torch (default: torch).

Returns

scores and indices

Return type

tuple[torch.Tensor, torch.Tensor]

homura.modules.functional.knn.torch_knn(keys, queries, num_neighbors, distance)

k-nearest neighbor search using torch. Use k_nearest_neighbor instead of calling this function directly.

Parameters
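  • keys (torch.Tensor) – tensor of (num_keys, dim)

  • queries (torch.Tensor) – tensor of (num_queries, dim)

  • num_neighbors (int) – k

  • distance (str) – name of the distance (inner_product or l2)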

Return type

tuple[torch.Tensor, torch.Tensor]
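
A plain-PyTorch sketch of this kind of search: compute all query-key scores, then keep the top k along the key axis (largest for inner_product, smallest for l2). It is an assumption that homura computes scores this way. Note that faiss reports squared L2 distances while torch.cdist returns unsquared ones, so l2 scores from this sketch may differ from the library's by a square even when the indices agree. torch_knn_sketch is a hypothetical name.

```python
from typing import Tuple

import torch


def torch_knn_sketch(keys: torch.Tensor,
                     queries: torch.Tensor,
                     num_neighbors: int,
                     distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical re-implementation, not homura's code.
    if distance == "inner_product":
        # (num_queries, num_keys); a larger inner product means more similar
        scores = queries @ keys.t()
        values, indices = scores.topk(num_neighbors, dim=1, largest=True)
    elif distance == "l2":
        # (num_queries, num_keys) Euclidean distances; smaller means closer
        scores = torch.cdist(queries, keys)
        values, indices = scores.topk(num_neighbors, dim=1, largest=False)
    else:
        raise ValueError(f"unsupported distance: {distance}")
    return values, indices


keys = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
queries = torch.tensor([[0.9, 0.1]])
scores, indices = torch_knn_sketch(keys, queries, 2, "l2")
print(indices)  # tensor([[1, 0]])
```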

homura.modules.functional.loss module

homura.modules.functional.loss.cross_entropy_with_smoothing(input, target, smoothing, dim=1, reduction='mean')

Cross entropy with label smoothing.

Parameters
  • input (torch.Tensor) –

  • target (torch.Tensor) –

  • smoothing (float) –

  • dim (int) –

  • reduction (str) –

Returns

the cross-entropy loss, reduced according to reduction

Return type

torch.Tensor
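
A minimal sketch of label-smoothed cross entropy, assuming the common convention in which the target distribution puts (1 - smoothing) on the true class and spreads smoothing / C uniformly over all C classes. Some implementations spread smoothing over the C - 1 other classes instead, and the docs above do not say which convention homura uses. PyTorch 1.10 and later implement the first convention in torch.nn.functional.cross_entropy via its label_smoothing argument, which gives a convenient cross-check. cross_entropy_with_smoothing_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def cross_entropy_with_smoothing_sketch(input: torch.Tensor,
                                        target: torch.Tensor,
                                        smoothing: float,
                                        dim: int = 1,
                                        reduction: str = "mean") -> torch.Tensor:
    # Hypothetical re-implementation, not homura's code.
    log_prob = F.log_softmax(input, dim=dim)
    num_classes = input.size(dim)
    # One-hot encoding of the integer class targets along `dim`.
    one_hot = torch.zeros_like(log_prob).scatter_(dim, target.unsqueeze(dim), 1.0)
    # Assumption: (1 - smoothing) * one_hot + smoothing / C.
    soft_target = one_hot * (1 - smoothing) + smoothing / num_classes
    loss = -(soft_target * log_prob).sum(dim=dim)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"


torch.manual_seed(0)
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
loss = cross_entropy_with_smoothing_sketch(logits, labels, smoothing=0.1)
```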

homura.modules.functional.loss.cross_entropy_with_softlabels(input, target, dim=1, reduction='mean')

Cross entropy with soft labels.

Parameters
  • input (torch.Tensor) –

  • target (torch.Tensor) –

  • dim (int) –

  • reduction (str) –

Returns

the cross-entropy loss, reduced according to reduction

Return type

torch.Tensor
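
With soft labels, target is usually a probability distribution over classes along dim, and the loss is the cross entropy between that distribution and softmax(input). The sketch below assumes this reading of target; with one-hot targets it reduces to the ordinary cross entropy. cross_entropy_with_softlabels_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def cross_entropy_with_softlabels_sketch(input: torch.Tensor,
                                         target: torch.Tensor,
                                         dim: int = 1,
                                         reduction: str = "mean") -> torch.Tensor:
    # Hypothetical re-implementation, not homura's code.
    # Assumption: target holds class probabilities along `dim`.
    loss = -(target * F.log_softmax(input, dim=dim)).sum(dim=dim)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"


torch.manual_seed(0)
logits = torch.randn(8, 5)
probs = torch.softmax(torch.randn(8, 5), dim=1)  # soft targets, rows sum to 1
loss = cross_entropy_with_softlabels_sketch(logits, probs)
```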

Module contents