Skip to content

Dot Product

dot_product_score(spot_colors, bled_codes, norm_shift=0, weight_squared=None)

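Computes $\frac{\sum_{r,c} W^2_{rc}\, s_{rc}\, b_{rc}}{\sum_{r,c} W^2_{rc}}$ where $s$ is a (normalised) spot_color, $b$ is a (normalised) bled_code and $W^2$ is weight_squared for a particular spot_color. The sums are over all rounds $r$ and channels $c$. In the weighted case the result is also multiplied by the number of round/channel pairs, so the maximum score is 1 when all weights are equal and the match is perfect.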

Parameters:

Name Type Description Default
spot_colors np.ndarray

float [n_spots x (n_rounds x n_channels)]. Spot colors normalised to equalise intensities between channels (and rounds).

required
bled_codes np.ndarray

float [n_genes x (n_rounds x n_channels)]. bled_codes such that spot_color of a gene g in round r is expected to be a constant multiple of bled_codes[g, r].

required
norm_shift float

shift to apply to normalisation of spot_colors to limit boost of weak spots.

0
weight_squared Optional[np.ndarray]

float [n_spots x (n_rounds x n_channels)]. squared weight to apply to each round/channel for each spot when computing dot product. If None, all rounds, channels treated equally.

None

Returns:

Type Description
np.ndarray

float [n_spots x n_genes]. score such that score[d, c] gives dot product between spot_colors vector d with bled_codes vector c.

Source code in coppafish/call_spots/dot_product.py (lines 6–53)
def dot_product_score(spot_colors: np.ndarray, bled_codes: np.ndarray, norm_shift: float = 0,
                      weight_squared: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Computes `n_rounds * n_channels * sum(W**2 * s * b) / sum(W**2)` where `s` is a normalised `spot_color`,
    `b` is a normalised `bled_code` and `W**2` is `weight_squared` for a particular `spot_color`.
    Sum is over all rounds and channels. If `weight_squared` is `None`, this reduces to `sum(s * b)`.

    Normalisation: `s = spot_color / (||spot_color|| + norm_shift)` and `b = bled_code / ||bled_code||`
    (L2 norms). A spot or bled code whose normalisation factor is 0 (e.g. an all-zero vector with
    `norm_shift == 0`) is left unscaled, so it scores 0 rather than producing NaN.

    Args:
        spot_colors: `float [n_spots x (n_rounds x n_channels)]`.
            Spot colors normalised to equalise intensities between channels (and rounds).
        bled_codes: `float [n_genes x (n_rounds x n_channels)]`.
            `bled_codes` such that `spot_color` of a gene `g`
            in round `r` is expected to be a constant multiple of `bled_codes[g, r]`.
        norm_shift: shift to apply to normalisation of spot_colors to limit boost of weak spots.
        weight_squared: `float [n_spots x (n_rounds x n_channels)]`.
            squared weight to apply to each round/channel for each spot when computing dot product.
            If `None`, all rounds, channels treated equally.

    Returns:
        `float [n_spots x n_genes]`.
            `score` such that `score[d, c]` gives dot product between `spot_colors` vector `d`
            with `bled_codes` vector `c`.

    Raises:
        utils.errors.ShapeError: If `spot_colors` is not `[n_spots x n_round_channels]` with the same
            number of round/channels as `bled_codes`, or if `weight_squared` does not match `spot_colors`.
    """
    n_spots = spot_colors.shape[0]
    _, n_round_channels = bled_codes.shape
    # Check the whole array (not spot_colors[0]) so that an empty, 0-spot input is accepted
    # instead of raising IndexError.
    if not utils.errors.check_shape(spot_colors, (n_spots, n_round_channels)):
        raise utils.errors.ShapeError('spot_colors', spot_colors.shape,
                                      (n_spots, n_round_channels))

    spot_norm_factor = np.linalg.norm(spot_colors, axis=1, keepdims=True) + norm_shift
    # Same guard as for bled codes below: an all-zero spot with norm_shift == 0 would otherwise give 0/0 = NaN.
    spot_norm_factor[spot_norm_factor == 0] = 1
    spot_colors = spot_colors / spot_norm_factor

    gene_norm_factor = np.linalg.norm(bled_codes, axis=1, keepdims=True)
    gene_norm_factor[gene_norm_factor == 0] = 1  # so don't blow up if bled_code is all 0 for a gene.
    bled_codes = bled_codes / gene_norm_factor

    if weight_squared is not None:
        if not utils.errors.check_shape(weight_squared, spot_colors.shape):
            raise utils.errors.ShapeError('weight', weight_squared.shape,
                                          spot_colors.shape)
        spot_colors = spot_colors * weight_squared

    score = spot_colors @ bled_codes.transpose()

    if weight_squared is not None:
        score = score / np.sum(weight_squared, axis=1, keepdims=True)
        score = score * n_round_channels  # make maximum score 1 if all weight the same and dot product perfect.

    return score

dot_product_score_no_weight(spot_colors, bled_codes, norm_shift=0)

Computes sum((s * b)) where s is a spot_color, b is a bled_code. Sum is over all rounds and channels.

Parameters:

Name Type Description Default
spot_colors np.ndarray

float [n_spots x (n_rounds x n_channels)]. Spot colors normalised to equalise intensities between channels (and rounds).

required
bled_codes np.ndarray

float [n_genes x (n_rounds x n_channels)]. bled_codes such that spot_color of a gene g in round r is expected to be a constant multiple of bled_codes[g, r].

required
norm_shift float

shift to apply to normalisation of spot_colors to limit boost of weak spots.

0

Returns:

Type Description
np.ndarray

float [n_spots x n_genes]. score such that score[d, c] gives dot product between spot_colors vector d with bled_codes vector c.

Source code in coppafish/call_spots/dot_product.py (lines 56–88)
def dot_product_score_no_weight(spot_colors: np.ndarray, bled_codes: np.ndarray, norm_shift: float = 0) -> np.ndarray:
    """
    Computes `sum(s * b)` where `s` is a normalised `spot_color` and `b` is a normalised `bled_code`.
    Sum is over all rounds and channels.

    Normalisation: `s = spot_color / (||spot_color|| + norm_shift)` and `b = bled_code / ||bled_code||`
    (L2 norms). A spot or bled code whose normalisation factor is 0 (e.g. an all-zero vector with
    `norm_shift == 0`) is left unscaled, so it scores 0 rather than producing NaN.

    Args:
        spot_colors: `float [n_spots x (n_rounds x n_channels)]`.
            Spot colors normalised to equalise intensities between channels (and rounds).
        bled_codes: `float [n_genes x (n_rounds x n_channels)]`.
            `bled_codes` such that `spot_color` of a gene `g`
            in round `r` is expected to be a constant multiple of `bled_codes[g, r]`.
        norm_shift: shift to apply to normalisation of spot_colors to limit boost of weak spots.

    Returns:
        `float [n_spots x n_genes]`.
            `score` such that `score[d, c]` gives dot product between `spot_colors` vector `d`
            with `bled_codes` vector `c`.

    Raises:
        utils.errors.ShapeError: If `spot_colors` is not `[n_spots x n_round_channels]` with the same
            number of round/channels as `bled_codes`.
    """
    n_spots = spot_colors.shape[0]
    _, n_round_channels = bled_codes.shape
    # Check the whole array (not spot_colors[0]) so that an empty, 0-spot input is accepted
    # instead of raising IndexError.
    if not utils.errors.check_shape(spot_colors, (n_spots, n_round_channels)):
        raise utils.errors.ShapeError('spot_colors', spot_colors.shape,
                                      (n_spots, n_round_channels))

    spot_norm_factor = np.linalg.norm(spot_colors, axis=1, keepdims=True) + norm_shift
    # Same guard as for bled codes below: an all-zero spot with norm_shift == 0 would otherwise give 0/0 = NaN.
    spot_norm_factor[spot_norm_factor == 0] = 1
    spot_colors = spot_colors / spot_norm_factor

    gene_norm_factor = np.linalg.norm(bled_codes, axis=1, keepdims=True)
    gene_norm_factor[gene_norm_factor == 0] = 1  # so don't blow up if bled_code is all 0 for a gene.
    bled_codes = bled_codes / gene_norm_factor

    return spot_colors @ bled_codes.transpose()

Optimised

dot_product_score(spot_colors, bled_codes, norm_shift, weight_squared)

Computes $\frac{\sum_{r,c} W^2_{rc}\, s_{rc}\, b_{rc}}{\sum_{r,c} W^2_{rc}}$ where $s$ is a spot_color, $b$ is a bled_code and $W^2$ is weight_squared for a particular spot_color. The sums are over all rounds $r$ and channels $c$.

Parameters:

Name Type Description Default
spot_colors jnp.ndarray

float [n_spots x (n_rounds x n_channels)]. Spot colors normalised to equalise intensities between channels (and rounds).

required
bled_codes jnp.ndarray

float [n_genes x (n_rounds x n_channels)]. bled_codes such that spot_color of a gene g in round r is expected to be a constant multiple of bled_codes[g, r].

required
norm_shift float

shift to apply to normalisation of spot_colors to limit boost of weak spots.

required
weight_squared jnp.ndarray

float [n_spots x (n_rounds x n_channels)]. squared weight to apply to each round/channel for each spot when computing dot product.

required

Returns:

Type Description
jnp.ndarray

float [n_spots x n_genes]. score such that score[d, c] gives dot product between spot_colors vector d with bled_codes vector c.

Source code in coppafish/call_spots/dot_product_optimised.py (lines 17–41)
@partial(jax.jit, static_argnums=2)
def dot_product_score(spot_colors: jnp.ndarray, bled_codes: jnp.ndarray, norm_shift: float,
                      weight_squared: jnp.ndarray) -> jnp.ndarray:
    """
    Weighted dot-product score between every spot and every gene's bled code.

    Conceptually computes `sum(W**2 * s * b) / sum(W**2)` per (spot, gene) pair, where `s` is a
    `spot_color`, `b` is a `bled_code` and `W**2` is `weight_squared` for that spot; the sums run over
    all rounds and channels. The per-spot work is done by `dot_product_score_single`, vectorised over
    spots with `jax.vmap`. The whole function is JIT-compiled with `norm_shift` as a static argument,
    so each distinct `norm_shift` value triggers its own compilation.

    Args:
        spot_colors: `float [n_spots x (n_rounds x n_channels)]`.
            Spot colors normalised to equalise intensities between channels (and rounds).
        bled_codes: `float [n_genes x (n_rounds x n_channels)]`.
            `bled_codes` such that `spot_color` of a gene `g`
            in round `r` is expected to be a constant multiple of `bled_codes[g, r]`.
        norm_shift: shift to apply to normalisation of spot_colors to limit boost of weak spots.
        weight_squared: `float [n_spots x (n_rounds x n_channels)]`.
            squared weight to apply to each round/channel for each spot when computing dot product.

    Returns:
        `float [n_spots x n_genes]`.
            `score` such that `score[d, c]` gives dot product between `spot_colors` vector `d`
            with `bled_codes` vector `c`.
    """
    # Map over axis 0 of spot_colors and weight_squared (one row per spot); bled_codes and
    # norm_shift are broadcast unchanged to every spot. Results are stacked along axis 0.
    score_each_spot = jax.vmap(dot_product_score_single, in_axes=(0, None, None, 0), out_axes=0)
    return score_each_spot(spot_colors, bled_codes, norm_shift, weight_squared)

dot_product_score_no_weight(spot_colors, bled_codes, norm_shift)

Computes sum((s * b)) where s is a spot_color, b is a bled_code. Sum is over all rounds and channels.

Parameters:

Name Type Description Default
spot_colors jnp.ndarray

float [n_spots x (n_rounds x n_channels)]. Spot colors normalised to equalise intensities between channels (and rounds).

required
bled_codes jnp.ndarray

float [n_genes x (n_rounds x n_channels)]. bled_codes such that spot_color of a gene g in round r is expected to be a constant multiple of bled_codes[g, r].

required
norm_shift float

shift to apply to normalisation of spot_colors to limit boost of weak spots.

required

Returns:

Type Description
jnp.ndarray

float [n_spots x n_genes]. score such that score[d, c] gives dot product between spot_colors vector d with bled_codes vector c.

Source code in coppafish/call_spots/dot_product_optimised.py (lines 51–72)
@partial(jax.jit, static_argnums=2)
def dot_product_score_no_weight(spot_colors: jnp.ndarray, bled_codes: jnp.ndarray, norm_shift: float) -> jnp.ndarray:
    """
    Unweighted dot-product score between every spot and every gene's bled code.

    Conceptually computes `sum(s * b)` per (spot, gene) pair, where `s` is a `spot_color` and `b` is a
    `bled_code`; the sum runs over all rounds and channels. The per-spot work is done by
    `dot_product_score_no_weight_single`, vectorised over spots with `jax.vmap`. The function is
    JIT-compiled with `norm_shift` as a static argument.

    Args:
        spot_colors: `float [n_spots x (n_rounds x n_channels)]`.
            Spot colors normalised to equalise intensities between channels (and rounds).
        bled_codes: `float [n_genes x (n_rounds x n_channels)]`.
            `bled_codes` such that `spot_color` of a gene `g`
            in round `r` is expected to be a constant multiple of `bled_codes[g, r]`.
        norm_shift: shift to apply to normalisation of spot_colors to limit boost of weak spots.

    Returns:
        `float [n_spots x n_genes]`.
            `score` such that `score[d, c]` gives dot product between `spot_colors` vector `d`
            with `bled_codes` vector `c`.
    """
    # Only spot_colors is mapped (axis 0, one row per spot); bled_codes and norm_shift are shared.
    score_each_spot = jax.vmap(dot_product_score_no_weight_single, in_axes=(0, None, None), out_axes=0)
    return score_each_spot(spot_colors, bled_codes, norm_shift)