diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 50a1c84417..c8f7d36eaf 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -4030,6 +4030,23 @@ test_pair_coalescence_counts_missing(void) &ts, 5, missing_ex_nodes, missing_ex_edges, NULL, NULL, NULL, NULL, NULL, 0); verify_pair_coalescence_counts(&ts, 0); verify_pair_coalescence_counts(&ts, TSK_STAT_SPAN_NORMALISE); + verify_pair_coalescence_counts(&ts, TSK_STAT_PAIR_NORMALISE); + verify_pair_coalescence_counts( + &ts, TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE); + tsk_treeseq_free(&ts); +} + +static void +test_pair_coalescence_counts_internal(void) +{ + tsk_treeseq_t ts; + tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, + NULL, NULL, NULL, NULL, NULL, 0); + verify_pair_coalescence_counts(&ts, 0); + verify_pair_coalescence_counts(&ts, TSK_STAT_SPAN_NORMALISE); + verify_pair_coalescence_counts(&ts, TSK_STAT_PAIR_NORMALISE); + verify_pair_coalescence_counts( + &ts, TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE); tsk_treeseq_free(&ts); } @@ -4164,6 +4181,8 @@ main(int argc, char **argv) { "test_pair_coalescence_counts", test_pair_coalescence_counts }, { "test_pair_coalescence_counts_missing", test_pair_coalescence_counts_missing }, + { "test_pair_coalescence_counts_internal", + test_pair_coalescence_counts_internal }, { "test_pair_coalescence_quantiles", test_pair_coalescence_quantiles }, { "test_pair_coalescence_rates", test_pair_coalescence_rates }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1aa06e5b03..94ca49c567 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -9804,6 +9804,23 @@ pair_coalescence_count(tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, } } +static inline void +update_pair_spans(tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_id_t node_set, double *count, double remainder, double *result) +{ + tsk_size_t i; + tsk_id_t j, k, v; + tsk_bug_assert(node_set != TSK_NULL); + for (i = 0; i < 
num_set_indexes; i++) { + j = set_indexes[2 * i]; + k = set_indexes[2 * i + 1]; + if (node_set == j || node_set == k) { + v = node_set == j ? k : j; + result[i] += count[v] * remainder; + } + } +} + int tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -9814,7 +9831,7 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp { int ret = 0; double left, right, remaining_span, missing_span, window_span, denominator, x, t; - tsk_id_t e, p, c, u, v, w, i, j; + tsk_id_t e, p, c, n, u, v, w, i, j; tsk_size_t num_samples, num_edges; tsk_tree_position_t tree_pos; const tsk_table_collection_t *tables = self->tables; @@ -9822,19 +9839,25 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp const double *restrict nodes_time = self->tables->nodes.time; const double sequence_length = tables->sequence_length; const tsk_size_t num_outputs = summary_func_dim; + const bool span_normalise = options & TSK_STAT_SPAN_NORMALISE; + const bool pair_normalise = options & TSK_STAT_PAIR_NORMALISE; /* buffers */ bool *visited = NULL; tsk_id_t *nodes_sample_set = NULL; tsk_id_t *nodes_parent = NULL; + tsk_id_t *nodes_degree = NULL; + double *total_samples = NULL; double *coalescing_pairs = NULL; double *coalescence_time = NULL; + double *total_pair_spans = NULL; double *nodes_sample = NULL; double *sample_count = NULL; + double *total_pair = NULL; double *bin_weight = NULL; double *bin_values = NULL; + double *bin_totals = NULL; double *pair_count = NULL; - double *total_pair = NULL; double *outside = NULL; /* row pointers */ @@ -9884,18 +9907,23 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp visited = tsk_malloc(num_nodes * sizeof(*visited)); outside = tsk_malloc(num_sample_sets * sizeof(*outside)); nodes_parent = tsk_malloc(num_nodes * sizeof(*nodes_parent)); + nodes_degree = 
tsk_calloc(num_nodes, sizeof(*nodes_degree)); nodes_sample = tsk_calloc(num_nodes * num_sample_sets, sizeof(*nodes_sample)); sample_count = tsk_malloc(num_nodes * num_sample_sets * sizeof(*sample_count)); coalescing_pairs = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescing_pairs)); coalescence_time = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescence_time)); bin_weight = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_weight)); bin_values = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_values)); - pair_count = tsk_malloc(num_set_indexes * sizeof(*pair_count)); + bin_totals = tsk_malloc(num_set_indexes * sizeof(*bin_totals)); + total_samples = tsk_calloc(num_sample_sets, sizeof(*total_samples)); + total_pair_spans = tsk_calloc(num_set_indexes, sizeof(*total_pair_spans)); total_pair = tsk_malloc(num_set_indexes * sizeof(*total_pair)); + pair_count = tsk_malloc(num_set_indexes * sizeof(*pair_count)); if (nodes_parent == NULL || nodes_sample == NULL || sample_count == NULL || coalescing_pairs == NULL || bin_weight == NULL || bin_values == NULL - || outside == NULL || pair_count == NULL || visited == NULL - || total_pair == NULL) { + || coalescence_time == NULL || outside == NULL || pair_count == NULL + || visited == NULL || total_samples == NULL || nodes_degree == NULL + || total_pair_spans == NULL || total_pair == NULL || bin_totals == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -9947,7 +9975,8 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp c = tables->edges.child[e]; nodes_parent[c] = TSK_NULL; inside = GET_2D_ROW(sample_count, num_sample_sets, c); - while (p != TSK_NULL) { /* downdate statistic */ + /* downdate numerator */ + while (p != TSK_NULL) { v = node_bin_map[p]; t = nodes_time[p]; if (v != TSK_NULL) { @@ -9968,7 +9997,29 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp p = nodes_parent[c]; } p = tables->edges.parent[e]; - while (p != 
TSK_NULL) { /* downdate state */ + c = tables->edges.child[e]; + /* downdate denominator */ + for (i = 0; i < 2; i++) { + n = i ? c : p; + v = nodes_sample_set[n]; + nodes_degree[n] -= 1; + if (v != TSK_NULL && nodes_degree[n] == 0) { + above = GET_2D_ROW(sample_count, num_sample_sets, n); + for (j = 0; j < (tsk_id_t) num_sample_sets; j++) { + outside[j] = total_samples[j] - above[j]; + } + update_pair_spans(num_set_indexes, set_indexes, v, outside, + -remaining_span, total_pair_spans); + total_samples[v] -= 1; + } + } + /* downdate state */ + while (p != TSK_NULL) { + v = nodes_sample_set[p]; + if (v != TSK_NULL && nodes_degree[p] > 0) { + update_pair_spans(num_set_indexes, set_indexes, v, inside, + remaining_span, total_pair_spans); + } above = GET_2D_ROW(sample_count, num_sample_sets, p); for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { above[i] -= inside[i]; @@ -9984,15 +10035,37 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp c = tables->edges.child[e]; nodes_parent[c] = p; inside = GET_2D_ROW(sample_count, num_sample_sets, c); - while (p != TSK_NULL) { /* update state */ + /* update state */ + while (p != TSK_NULL) { + v = nodes_sample_set[p]; above = GET_2D_ROW(sample_count, num_sample_sets, p); for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { above[i] += inside[i]; } + if (v != TSK_NULL && nodes_degree[p] > 0) { + update_pair_spans(num_set_indexes, set_indexes, v, inside, + -remaining_span, total_pair_spans); + } p = nodes_parent[p]; } p = tables->edges.parent[e]; - while (p != TSK_NULL) { /* update statistic */ + /* update denominator */ + for (i = 0; i < 2; i++) { + n = i ? 
c : p; + v = nodes_sample_set[n]; + if (v != TSK_NULL && nodes_degree[n] == 0) { + total_samples[v] += 1; + above = GET_2D_ROW(sample_count, num_sample_sets, n); + for (j = 0; j < (tsk_id_t) num_sample_sets; j++) { + outside[j] = total_samples[j] - above[j]; + } + update_pair_spans(num_set_indexes, set_indexes, v, outside, + remaining_span, total_pair_spans); + } + nodes_degree[n] += 1; + } + /* update numerator */ + while (p != TSK_NULL) { v = node_bin_map[p]; t = nodes_time[p]; if (v != TSK_NULL) { @@ -10023,14 +10096,31 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp while (w < (tsk_id_t) num_windows && windows[w + 1] <= right) { TRANSPOSE_2D(num_bins, num_set_indexes, coalescing_pairs, bin_weight); TRANSPOSE_2D(num_bins, num_set_indexes, coalescence_time, bin_values); + tsk_memcpy( + bin_totals, total_pair_spans, num_set_indexes * sizeof(*bin_totals)); tsk_memset(coalescing_pairs, 0, num_bins * num_set_indexes * sizeof(*coalescing_pairs)); tsk_memset(coalescence_time, 0, num_bins * num_set_indexes * sizeof(*coalescence_time)); + tsk_memset(total_pair_spans, 0, num_set_indexes * sizeof(*total_pair_spans)); remaining_span = sequence_length - windows[w + 1]; for (j = 0; j < (tsk_id_t) num_samples; j++) { /* truncate at tree */ c = sample_sets[j]; p = nodes_parent[c]; + /* split denominator */ + if (nodes_degree[c] > 0) { + v = nodes_sample_set[c]; + above = GET_2D_ROW(sample_count, num_sample_sets, c); + state = GET_2D_ROW(nodes_sample, num_sample_sets, c); + for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { + outside[i] = total_samples[i] + state[i] - 2 * above[i]; + } + update_pair_spans(num_set_indexes, set_indexes, v, outside, + remaining_span / 2, total_pair_spans); + update_pair_spans(num_set_indexes, set_indexes, v, outside, + -remaining_span / 2, bin_totals); + } + /* split numerator */ while (!visited[c] && p != TSK_NULL) { v = node_bin_map[p]; t = nodes_time[p]; @@ -10075,7 +10165,7 @@ 
tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp } } /* normalise weights */ - if (options & (TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE)) { + if (span_normalise || pair_normalise) { window_span = windows[w + 1] - windows[w] - missing_span; missing_span = 0.0; if (num_edges == 0) { @@ -10085,16 +10175,16 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp missing_span += remaining_span; } for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { - denominator = 1.0; - if (options & TSK_STAT_SPAN_NORMALISE) { - denominator *= window_span; - } - if (options & TSK_STAT_PAIR_NORMALISE) { + denominator = bin_totals[i] > 0 ? 1 / bin_totals[i] : 0; + if (span_normalise && !pair_normalise) { denominator *= total_pair[i]; } + if (!span_normalise && pair_normalise) { + denominator *= window_span; + } weight = GET_2D_ROW(bin_weight, num_bins, i); for (v = 0; v < (tsk_id_t) num_bins; v++) { - weight[v] *= denominator == 0.0 ? 0.0 : 1 / denominator; + weight[v] *= denominator; } } } @@ -10117,13 +10207,17 @@ tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_samp tsk_safe_free(nodes_sample_set); tsk_safe_free(coalescing_pairs); tsk_safe_free(coalescence_time); + tsk_safe_free(total_pair_spans); + tsk_safe_free(total_samples); + tsk_safe_free(nodes_degree); tsk_safe_free(nodes_parent); tsk_safe_free(nodes_sample); tsk_safe_free(sample_count); + tsk_safe_free(total_pair); tsk_safe_free(bin_weight); tsk_safe_free(bin_values); + tsk_safe_free(bin_totals); tsk_safe_free(pair_count); - tsk_safe_free(total_pair); tsk_safe_free(visited); tsk_safe_free(outside); return ret; diff --git a/python/tests/test_coalrate.py b/python/tests/test_coalrate.py index 8281864fcd..9d4899c495 100644 --- a/python/tests/test_coalrate.py +++ b/python/tests/test_coalrate.py @@ -57,6 +57,26 @@ def _single_tree_example(L, T): return tables.tree_sequence() +def _remove_partial_ancestry(tables, sample, left, right): + """ + 
Remove ancestry for a particular sample over [left, right) + """ + for position in [left, right]: + singletons = tables.edges.child == sample + for i in np.flatnonzero(singletons): + if tables.edges.left[i] < position < tables.edges.right[i]: + tables.edges.append(tables.edges[i].replace(left=position)) + tables.edges[i] = tables.edges[i].replace(right=position) + drop = np.logical_and.reduce( + [ + tables.edges.left >= left, + tables.edges.right <= right, + tables.edges.child == sample, + ] + ) + tables.edges.keep_rows(~drop) + + # --- prototype --- # @@ -90,6 +110,18 @@ def _nonmissing_window_span(ts, windows): return window_span +def _sample_set_pairs(ts, sample_sets, indexes): + sample_set_sizes = np.array([len(x) for x in sample_sets]) + total_pairs = np.zeros(len(indexes)) + for i, (j, k) in enumerate(indexes): + total_pairs[i] = ( + sample_set_sizes[j] * sample_set_sizes[k] + if j != k + else sample_set_sizes[j] * (sample_set_sizes[j] - 1) / 2 + ) + return total_pairs + + def _pair_coalescence_weights( coalescing_pairs, nodes_time, @@ -234,6 +266,9 @@ def _pair_coalescence_stat( nodes_map[nodes_oob] = tskit.NULL num_time_windows = time_windows.size - 1 + window_span = _nonmissing_window_span(ts, windows) + total_pairs = _sample_set_pairs(ts, sample_sets, indexes) + num_nodes = ts.num_nodes num_windows = windows.size - 1 num_sample_sets = len(sample_sets) @@ -246,28 +281,45 @@ def _pair_coalescence_stat( output_size = summary_func_dim samples = np.concatenate(sample_sets) + nodes_set = np.full(num_nodes, tskit.NULL) nodes_parent = np.full(num_nodes, tskit.NULL) nodes_sample = np.zeros((num_nodes, num_sample_sets)) nodes_weight = np.zeros((num_time_windows, num_indexes)) nodes_values = np.zeros((num_time_windows, num_indexes)) + nodes_degree = np.zeros(num_nodes, dtype=np.int32) + window_denom = np.zeros(num_indexes) coalescing_pairs = np.zeros((num_time_windows, num_indexes)) coalescence_time = np.zeros((num_time_windows, num_indexes)) + total_pair_spans = 
np.zeros(num_indexes) + sample_sizes = np.zeros(num_sample_sets) output = np.zeros((num_windows, output_size, num_indexes)) visited = np.full(num_nodes, False) - total_pairs = np.zeros(num_indexes) - sizes = [len(s) for s in sample_sets] - for i, (j, k) in enumerate(indexes): - if j == k: - total_pairs[i] = sizes[j] * (sizes[k] - 1) / 2 - else: - total_pairs[i] = sizes[j] * sizes[k] - - if span_normalise: - window_span = _nonmissing_window_span(ts, windows) + # Jointly update numerator (observed number of coalescing pairs at a node) + # and denominator (the total number of coalescing pairs, integrated over + # nonmissing span). The denominator is complicated by the fact that samples + # may be ancestral to other samples: we do not permit a pair composed of a + # sample and its descendant to coalesce. + + def update_numerator(u, n, m, r, t, x, y): + assert u != tskit.NULL + for i, (j, k) in enumerate(indexes): + w = n[j] * m[k] + if j != k: + w += n[k] * m[j] + x[u, i] += w * r + y[u, i] += w * r * t + + def update_denominator(s, n, r, x): + assert s != tskit.NULL + for i, (j, k) in enumerate(indexes): + if s == j or s == k: + v = k if s == j else j + x[i] += n[v] * r for i, s in enumerate(sample_sets): # initialize nodes_sample[s, i] = 1 + nodes_set[s] = i sample_counts = nodes_sample.copy() w = 0 @@ -284,22 +336,55 @@ def _pair_coalescence_stat( c = edges_child[e] nodes_parent[c] = tskit.NULL inside = sample_counts[c] + # decrement numerator while p != tskit.NULL: - u = nodes_map[p] - t = nodes_time[p] - if u != tskit.NULL: + if nodes_map[p] != tskit.NULL: outside = sample_counts[p] - sample_counts[c] - nodes_sample[p] - for i, (j, k) in enumerate(indexes): - weight = inside[j] * outside[k] - if j != k: - weight += inside[k] * outside[j] - coalescing_pairs[u, i] -= weight * remainder - coalescence_time[u, i] -= weight * remainder * t + update_numerator( + nodes_map[p], + inside, + outside, + -remainder, + nodes_time[p], + coalescing_pairs, + coalescence_time, + ) 
c, p = p, nodes_parent[p] p = edges_parent[e] + c = edges_child[e] + # decrement denominator at insertion + nodes_degree[c] -= 1 + assert not nodes_degree[c] < 0 + if nodes_set[c] != tskit.NULL and nodes_degree[c] == 0: + update_denominator( + nodes_set[c], + sample_sizes - sample_counts[c], + -remainder, + total_pair_spans, + ) + sample_sizes[nodes_set[c]] -= 1 + nodes_degree[p] -= 1 + assert not nodes_degree[p] < 0 + if nodes_set[p] != tskit.NULL and nodes_degree[p] == 0: + update_denominator( + nodes_set[p], + sample_sizes - sample_counts[p], + -remainder, + total_pair_spans, + ) + sample_sizes[nodes_set[p]] -= 1 + # decrement denominator above insertion and sample counts while p != tskit.NULL: + if nodes_set[p] != tskit.NULL and nodes_degree[p] > 0: + update_denominator( + nodes_set[p], + inside, + remainder, + total_pair_spans, + ) sample_counts[p] -= inside p = nodes_parent[p] + p = edges_parent[e] for b in range(in_range.start, in_range.stop): # edges_in e = in_range.order[b] @@ -307,49 +392,104 @@ def _pair_coalescence_stat( c = edges_child[e] nodes_parent[c] = p inside = sample_counts[c] + # increment sample counts and denominator above insertion while p != tskit.NULL: sample_counts[p] += inside + if nodes_set[p] != tskit.NULL and nodes_degree[p] > 0: + update_denominator( + nodes_set[p], + -inside, + remainder, + total_pair_spans, + ) p = nodes_parent[p] p = edges_parent[e] + # increment denominator at insertion + if nodes_set[c] != tskit.NULL and nodes_degree[c] == 0: + sample_sizes[nodes_set[c]] += 1 + update_denominator( + nodes_set[c], + sample_sizes - sample_counts[c], + remainder, + total_pair_spans, + ) + if nodes_set[p] != tskit.NULL and nodes_degree[p] == 0: + sample_sizes[nodes_set[p]] += 1 + update_denominator( + nodes_set[p], + sample_sizes - sample_counts[p], + remainder, + total_pair_spans, + ) + nodes_degree[c] += 1 + nodes_degree[p] += 1 + # increment numerator while p != tskit.NULL: - u = nodes_map[p] - t = nodes_time[p] - if u != 
tskit.NULL: + if nodes_map[p] != tskit.NULL: outside = sample_counts[p] - sample_counts[c] - nodes_sample[p] - for i, (j, k) in enumerate(indexes): - weight = inside[j] * outside[k] - if j != k: - weight += inside[k] * outside[j] - coalescing_pairs[u, i] += weight * remainder - coalescence_time[u, i] += weight * remainder * t + update_numerator( + nodes_map[p], + inside, + outside, + remainder, + nodes_time[p], + coalescing_pairs, + coalescence_time, + ) c, p = p, nodes_parent[p] + p = edges_parent[e] + c = edges_child[e] while w < num_windows and windows[w + 1] <= right: # flush window remainder = sequence_length - windows[w + 1] nodes_weight[:] = coalescing_pairs[:] nodes_values[:] = coalescence_time[:] + window_denom[:] = total_pair_spans[:] coalescing_pairs[:] = 0.0 coalescence_time[:] = 0.0 + total_pair_spans[:] = 0.0 for c in samples: + # split denominator + if nodes_degree[c] > 0: + update_denominator( + nodes_set[c], + nodes_sample[c] + sample_sizes - 2 * sample_counts[c], + -remainder / 2, + window_denom, + ) + update_denominator( + nodes_set[c], + nodes_sample[c] + sample_sizes - 2 * sample_counts[c], + remainder / 2, + total_pair_spans, + ) + # split numerator p = nodes_parent[c] while not visited[c] and p != tskit.NULL: - u = nodes_map[p] - t = nodes_time[p] - if u != tskit.NULL: + if nodes_map[p] != tskit.NULL: inside = sample_counts[c] outside = sample_counts[p] - sample_counts[c] - nodes_sample[p] - for i, (j, k) in enumerate(indexes): - weight = inside[j] * outside[k] - if j != k: - weight += inside[k] * outside[j] - x = weight * remainder / 2 - nodes_weight[u, i] -= x - nodes_values[u, i] -= t * x - coalescing_pairs[u, i] += x - coalescence_time[u, i] += t * x + update_numerator( + nodes_map[p], + inside, + outside, + -remainder / 2, + nodes_time[p], + nodes_weight, + nodes_values, + ) + update_numerator( + nodes_map[p], + inside, + outside, + remainder / 2, + nodes_time[p], + coalescing_pairs, + coalescence_time, + ) visited[c] = True p, c = 
nodes_parent[p], p - for c in samples: + for c in samples: # reset guard p = nodes_parent[c] while visited[c] and p != tskit.NULL: visited[c] = False @@ -358,11 +498,20 @@ def _pair_coalescence_stat( nonzero = nodes_weight[:, i] > 0 nodes_values[nonzero, i] /= nodes_weight[nonzero, i] nodes_values[~nonzero, i] = np.nan - if span_normalise: - nodes_weight /= window_span[w] - if pair_normalise: - nodes_weight /= total_pairs[np.newaxis, :] - for i in range(num_indexes): # apply function to empirical distribution + # normalise weights + if pair_normalise or span_normalise: + nodes_weight = np.divide( + nodes_weight, + window_denom[np.newaxis], + out=np.zeros_like(nodes_weight), + where=window_denom[np.newaxis] > 0, + ) + if span_normalise and not pair_normalise: + nodes_weight *= total_pairs[np.newaxis] + if pair_normalise and not span_normalise: + nodes_weight *= window_span[w] + # apply function to empirical distribution + for i in range(num_indexes): output[w, :, i] = summary_func( nodes_weight[:, i], nodes_values[:, i], @@ -510,7 +659,7 @@ def proto_pair_coalescence_rates( :param list windows: An increasing list of breakpoints between the sequence windows to compute the statistic in, or None. """ - # TODO^^^ + # TODO: update for parity with main method docstring if not (isinstance(time_windows, np.ndarray) and time_windows.size > 1): raise ValueError("Time windows must be an array of breakpoints") @@ -614,6 +763,17 @@ def naive_pair_coalescence_counts(ts, sample_set_0, sample_set_1): given node is the product of the number of samples subtended by the left and right child. For higher arities, the count is summed over all possible pairs of children. + + Two properties worth noting: + + - Ancestor-descendant pairs are not counted at any node: at an + internal sample S, S itself is not in any of S's children's subtrees; + at any ancestor of S, S and its descendants live in the same + child-subtree, so they are never paired via the cross-children sum. 
+ + - Samples detached from the tree over an interval are not enumerated + by t.samples(p) for any internal p, and so contribute zero pair + count over those intervals. """ output = np.zeros(ts.num_nodes) for t in ts.trees(): @@ -772,7 +932,6 @@ def test_total_pairs(self): check = np.array([0.0] * 8 + [1, 2, 1, 5, 4, 15]) implm = ts.pair_coalescence_counts() np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts) np.testing.assert_allclose(proto, check) @@ -800,7 +959,6 @@ def test_population_pairs(self): check[1] = np.array([0.0] * 8 + [0, 0, 0, 0, 4, 12]) check[2] = np.array([0.0] * 8 + [1, 2, 0, 0, 0, 3]) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[ss0, ss1], indexes=indexes ) @@ -831,10 +989,33 @@ def test_internal_samples(self): implm = ts.pair_coalescence_counts(span_normalise=False) check = np.array([0] * 8 + [1, 2, 1, 5, 5, 24]) * ts.sequence_length np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, span_normalise=False) np.testing.assert_allclose(proto, check) + def test_internal_samples_normalised(self): + """ + See `test_internal_samples`, but checking correctness of + normalisation + """ + ts = self.example_ts() + tables = ts.dump_tables() + nodes_flags = tables.nodes.flags.copy() + nodes_flags[9] = tskit.NODE_IS_SAMPLE + nodes_flags[11] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = nodes_flags + ts = tables.tree_sequence() + assert ts.num_samples == 10 + unnormalised = ts.pair_coalescence_counts(span_normalise=False) + implm = ts.pair_coalescence_counts(span_normalise=True, pair_normalise=True) + check = unnormalised / unnormalised.sum() + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(implm.sum(), 1.0) + proto = proto_pair_coalescence_counts( + ts, span_normalise=True, pair_normalise=True + ) + 
np.testing.assert_allclose(proto, check) + np.testing.assert_allclose(proto.sum(), 1.0) + def test_windows(self): ts = self.example_ts() check = np.array([0.0] * 8 + [1, 2, 1, 5, 4, 15]) * ts.sequence_length / 2 @@ -843,7 +1024,6 @@ def test_windows(self): ) np.testing.assert_allclose(implm[0], check) np.testing.assert_allclose(implm[1], check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, windows=np.linspace(0, ts.sequence_length, 3), span_normalise=False ) @@ -871,7 +1051,6 @@ def test_time_windows(self): span_normalise=False, time_windows=time_windows ) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, span_normalise=False, time_windows=time_windows ) @@ -894,7 +1073,6 @@ def test_pair_normalise(self): total_pairs = np.array([6, 16, 6]) check /= total_pairs[:, np.newaxis] np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[ss0, ss1], @@ -910,7 +1088,6 @@ def test_multiple_roots(self): check = np.array([0.0] * 8 + [1, 2, 1, 5, 0, 0, 0, 0]) check /= total_pairs np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, pair_normalise=True) np.testing.assert_allclose(proto, check) @@ -965,7 +1142,6 @@ def test_total_pairs(self): implm = ts.pair_coalescence_counts(span_normalise=False) check = np.array([0] * 4 + [1 * (L - S), 2 * (L - S) + 1 * S, 2 * S, 3 * L]) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, span_normalise=False) np.testing.assert_allclose(proto, check) @@ -994,7 +1170,6 @@ def test_population_pairs(self): check[1] = np.array([0] * 4 + [1 * (L - S), 1 * (L - S), 2 * S, 2 * L]) check[2] = np.array([0] * 4 + [0, 1 * L, 0, 0]) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[[0, 1], 
[2, 3]], indexes=indexes, span_normalise=False ) @@ -1023,10 +1198,32 @@ def test_internal_samples(self): implm = ts.pair_coalescence_counts(span_normalise=False) check = np.array([0.0] * 4 + [(L - S), S + 2 * (L - S), 3 * S, 4 * L]) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, span_normalise=False) np.testing.assert_allclose(proto, check) + def test_internal_samples_normalised(self): + """ + See `test_internal_samples`, but checking correctness of + normalisation + """ + L, S = 200, 100 + ts = self.example_ts(S, L) + tables = ts.dump_tables() + nodes_flags = tables.nodes.flags.copy() + nodes_flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = nodes_flags + ts = tables.tree_sequence() + unnormalised = ts.pair_coalescence_counts(span_normalise=False) + implm = ts.pair_coalescence_counts(span_normalise=True, pair_normalise=True) + check = unnormalised / unnormalised.sum() + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(implm.sum(), 1.0) + proto = proto_pair_coalescence_counts( + ts, span_normalise=True, pair_normalise=True + ) + np.testing.assert_allclose(proto, check) + np.testing.assert_allclose(proto.sum(), 1.0) + def test_windows(self): """ ┊ 3 pairs 3 ┊ @@ -1048,7 +1245,6 @@ def test_windows(self): implm = ts.pair_coalescence_counts(windows=windows, span_normalise=False) np.testing.assert_allclose(implm[0], check_0) np.testing.assert_allclose(implm[1], check_1) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, windows=windows, span_normalise=False) np.testing.assert_allclose(proto[0], check_0) np.testing.assert_allclose(proto[1], check_1) @@ -1079,7 +1275,6 @@ def test_time_windows(self): ) np.testing.assert_allclose(implm[0], check_0) np.testing.assert_allclose(implm[1], check_1) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, span_normalise=False, @@ -1106,7 +1301,6 @@ def test_pair_normalise(self): total_pairs = 
np.array([1, 4, 1]) check /= total_pairs[:, np.newaxis] np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[[0, 1], [2, 3]], @@ -1124,13 +1318,536 @@ def test_multiple_roots(self): check = np.array([0.0] * 4 + [1 * (L - S), 2 * (L - S) + 1 * S, 0, 0, 0, 0]) check /= total_pairs np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, pair_normalise=True, span_normalise=False ) np.testing.assert_allclose(proto, check) +class TestCoalescingPairsMissingSamples: + """ + Test against a worked example where samples are partially missing + """ + + def example_ts(self): + """ + 3.00┊ 8 ┊ ┊ ┊ ┊ + ┊ ┏━━┻━┓ ┊ ┊ ┊ ┊ + 2.00┊ 7 ┃ ┊ 7 ┊ 9 ┊ ┊ + ┊ ┏━┻━┓ ┃ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ ┊ + 1.00┊ 5 6 ┃ ┊ 5 ┃ ┊ 6 ┃ ┊ ┊ + ┊ ┏┻┓ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ ┊ + 0.00┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ + 0 1 2 3 5 + A A B B B A A B B B B + """ + tables = tskit.TableCollection(sequence_length=5) + for _ in range(2): + tables.populations.add_row() + tables.nodes.set_columns( + time=np.array([0, 0, 0, 0, 0, 1, 1, 2, 3, 2]), + flags=np.array([tskit.NODE_IS_SAMPLE] * 5 + [0] * 5, dtype=np.uint32), + population=np.array([0, 0, 1, 1, 1] + [tskit.NULL] * 5, dtype=np.int32), + ) + tables.edges.set_columns( + left=np.array([0, 0, 0, 2, 0, 2, 1, 0, 0, 2, 2, 0, 0]), + right=np.array([2, 2, 1, 3, 1, 3, 2, 2, 1, 3, 3, 1, 1]), + parent=np.array([5, 5, 6, 6, 6, 6, 7, 7, 7, 9, 9, 8, 8], dtype=np.int32), + child=np.array([0, 1, 2, 2, 3, 3, 2, 5, 6, 4, 6, 4, 7], dtype=np.int32), + ) + return tables.tree_sequence() + + def test_total_pairs_unnormalised(self): + """ + With no normalisation the counts should sum to the total number of + span-weighted pairs. 
+ """ + total_pair_span = 16.0 + ts = self.example_ts() + implm = ts.pair_coalescence_counts(span_normalise=False) + proto = proto_pair_coalescence_counts(ts, span_normalise=False) + np.testing.assert_allclose(implm.sum(), total_pair_span) + np.testing.assert_allclose(proto.sum(), total_pair_span) + + def test_pair_normalise_only(self): + """ + With `pair_normalise=True, span_normalise=False` the denominator + should equal `E[pair count] = total_pair_span / nonmissing_sequence` + """ + total_pair_span = 16.0 + nonmissing_sequence = 3.0 + ts = self.example_ts() + unnormalised = ts.pair_coalescence_counts( + span_normalise=False, pair_normalise=False + ) + check = unnormalised * nonmissing_sequence / total_pair_span + implm = ts.pair_coalescence_counts(span_normalise=False, pair_normalise=True) + proto = proto_pair_coalescence_counts( + ts, span_normalise=False, pair_normalise=True + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_span_normalise_only(self): + """ + With `pair_normalise=False, span_normalise=True` the denominator + should equal `E[pair span] = total_pair_span / total_pairs` + """ + total_pair_span = 16.0 + total_pairs = 5 * (5 - 1) / 2 + ts = self.example_ts() + unnormalised = ts.pair_coalescence_counts( + span_normalise=False, pair_normalise=False + ) + check = unnormalised * total_pairs / total_pair_span + implm = ts.pair_coalescence_counts(span_normalise=True, pair_normalise=False) + proto = proto_pair_coalescence_counts( + ts, span_normalise=True, pair_normalise=False + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_pair_and_span_normalise(self): + """ + With `pair_normalise=True, span_normalise=True` the denominator + should equal `total_pair_span` so that the "counts" sum to one + """ + total_pair_span = 16.0 + ts = self.example_ts() + unnormalised = ts.pair_coalescence_counts( + span_normalise=False, pair_normalise=False + ) + check = 
unnormalised / total_pair_span + implm = ts.pair_coalescence_counts(span_normalise=True, pair_normalise=True) + proto = proto_pair_coalescence_counts( + ts, span_normalise=True, pair_normalise=True + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + np.testing.assert_allclose(implm.sum(), 1.0) + np.testing.assert_allclose(proto.sum(), 1.0) + + def test_normalised_sums_to_unity(self): + """ + With both `pair_normalise=True, span_normalise=True` each window and + index pair should sum to one (over nodes) unless there are multiple + roots + """ + ts = self.example_ts() + ss = [[0, 1], [2, 3, 4]] + idx = [(0, 0), (0, 1), (1, 1)] + windows = ts.breakpoints(as_array=True) + check = np.array( + [ + [1, 1, 1], # tree 1 + [1, 1, 0], # tree 2 + [0, 0, 1], # tree 3 + [0, 0, 0], # tree 4 + ] + ) + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm.sum(axis=-1), check) + np.testing.assert_allclose(proto.sum(axis=-1), check) + + def test_population_pairs(self): + """ + Number of span-weighted pairs across sample sets + """ + ts = self.example_ts() + ss = [[0, 1], [2, 3, 4]] + idx = [(0, 0), (0, 1), (1, 1)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denom = np.array([2.0, 8.0, 6.0]) # (AA, AB, BB) + check = unnormalised / denom[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def 
test_trees(self): + """ + Per-tree denominator equals the pair count for unit-width trees, and + zero for the empty tree. By convention the "normalised" result returns + zero for the empty tree. + """ + ts = self.example_ts() + windows = ts.breakpoints(as_array=True) + expected_denominator = np.array([10.0, 3.0, 3.0, 0.0])[:, np.newaxis] + unnormalised = ts.pair_coalescence_counts( + windows=windows, span_normalise=False, pair_normalise=False + ) + check = np.zeros_like(unnormalised) + np.divide( + unnormalised, expected_denominator, out=check, where=expected_denominator > 0 + ) + implm = ts.pair_coalescence_counts( + windows=windows, span_normalise=True, pair_normalise=True + ) + proto = proto_pair_coalescence_counts( + ts, windows=windows, span_normalise=True, pair_normalise=True + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_windows(self): + """ + Expanding overlapped trees, window [0, 1.5) has denominator (10 + 3 * + 0.5), whereas window [1.5, 3.5) has denominator (3 * 0.5 + 3 + 0) + """ + ts = self.example_ts() + windows = np.array([0.0, 1.5, 3.5, 5.0]) + unnormalised = ts.pair_coalescence_counts( + windows=windows, span_normalise=False, pair_normalise=False + ) + denom_0 = 10 * 1.0 + 3 * 0.5 + denom_1 = 3 * 0.5 + 3 * 1.0 + 2.0 * 0.0 + check_0 = unnormalised[0] / denom_0 + check_1 = unnormalised[1] / denom_1 + implm = ts.pair_coalescence_counts( + windows=windows, span_normalise=True, pair_normalise=True + ) + proto = proto_pair_coalescence_counts( + ts, windows=windows, span_normalise=True, pair_normalise=True + ) + np.testing.assert_allclose(implm[0], check_0) + np.testing.assert_allclose(implm[1], check_1) + np.testing.assert_allclose(implm[2], 0.0) + np.testing.assert_allclose(proto[0], check_0) + np.testing.assert_allclose(proto[1], check_1) + np.testing.assert_allclose(proto[2], 0.0) + + def test_internal_sample(self): + """ + Flag node 6 (internal, time 1, population B) as a sample. 
It is present + in two trees, and cannot coalesce with its descendents, so the pair + counts per tree are: + + tree 1: AA=1, BB=4, AB=8 + tree 2: AA=1, BB=0, AB=2 + tree 3: AA=0, BB=4, AB=0 + tree 4: AA=0, BB=0, AB=0 + """ + tab = self.example_ts().dump_tables() + tab.nodes[6] = tab.nodes[6].replace(flags=tskit.NODE_IS_SAMPLE, population=1) + ts = tab.tree_sequence() + ss = [[0, 1], [2, 3, 4, 6]] + idx = [(0, 0), (0, 1), (1, 1)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denominators = np.array([2.0, 10.0, 8.0]) # AA, AB, BB + check = unnormalised / denominators[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_internal_sample_windowed(self): + """ + Flag node 6 (internal, time 1, population B) as a sample, use windows; + see docstring for 'test_internal_sample' + """ + tab = self.example_ts().dump_tables() + tab.nodes[6] = tab.nodes[6].replace(flags=tskit.NODE_IS_SAMPLE, population=1) + ts = tab.tree_sequence() + ss = [[0, 1], [2, 3, 4, 6]] + idx = [(0, 0), (0, 1), (1, 1)] + windows = np.array([0.0, 1.5, 3.5, 5.0]) + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=False, + pair_normalise=False, + ) + denom_0 = np.array([1.5, 9.0, 4.0]) # (AA, AB, BB) + denom_1 = np.array([0.5, 1.0, 4.0]) + check_0 = unnormalised[0] / denom_0[:, np.newaxis] + check_1 = unnormalised[1] / denom_1[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + 
indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm[0], check_0) + np.testing.assert_allclose(implm[1], check_1) + np.testing.assert_allclose(implm[2], 0.0) + np.testing.assert_allclose(proto[0], check_0) + np.testing.assert_allclose(proto[1], check_1) + np.testing.assert_allclose(proto[2], 0.0) + + def test_internal_sample_cross_coalescence(self): + """ + See docstring for `test_internal_sample`, but placing ancestor (node + 6) and its descendants (2, 3) in different sample sets + """ + tab = self.example_ts().dump_tables() + tab.nodes[6] = tab.nodes[6].replace(flags=tskit.NODE_IS_SAMPLE) + ts = tab.tree_sequence() + ss = [[0, 1, 6], [2, 3]] + idx = [(0, 0), (0, 1), (1, 1)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denominators = np.array([4.0, 6.0, 2.0]) # AA, AB, BB + check = unnormalised / denominators[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_nested_internal_samples(self): + """ + Flag internal node 5 and its ancestor node 7 as samples + """ + tab = self.example_ts().dump_tables() + tab.nodes[5] = tab.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + tab.nodes[7] = tab.nodes[7].replace(flags=tskit.NODE_IS_SAMPLE) + ts = tab.tree_sequence() + ss = [[0, 1, 2, 5, 7]] + idx = [(0, 0)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denominator = 8.0 + check = unnormalised / denominator + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto = 
proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_nested_internal_samples_with_span_normalise(self): + """ + Flag internal node 5 and its ancestor node 7 as samples. Verify + that `span_normalise` alone gives output proportional to the total + number of pairs (not the effective number of pairs). + """ + tab = self.example_ts().dump_tables() + tab.nodes[5] = tab.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + tab.nodes[7] = tab.nodes[7].replace(flags=tskit.NODE_IS_SAMPLE) + ts = tab.tree_sequence() + ss = [[0, 1, 2, 5, 7]] + idx = [(0, 0)] + P_raw = 10 # choose(5, 2) + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + total_pair_span = unnormalised.sum() + check = unnormalised * P_raw / total_pair_span + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=False, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=False, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_mixed_descendants(self): + """ + Flag internal node 7 as a sample and put its descendants in two + different sample sets + """ + tab = self.example_ts().dump_tables() + tab.nodes[7] = tab.nodes[7].replace(flags=tskit.NODE_IS_SAMPLE) + ts = tab.tree_sequence() + ss = [[0, 1, 7], [2, 3]] + idx = [(0, 0), (0, 1), (1, 1)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denominators = np.array([2.0, 6.0, 2.0]) # AA, AB, BB + check = unnormalised / denominators[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto 
= proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + def test_internal_sample_window_boundary(self): + """ + Test window flush when an internal sample crosses the window boundary + """ + tab = self.example_ts().dump_tables() + tab.nodes[6] = tab.nodes[6].replace(flags=tskit.NODE_IS_SAMPLE) + ts = tab.tree_sequence() + ss = [[2, 3, 4, 6]] + idx = [(0, 0)] + windows = np.array([0.0, 0.5, 1.5, 5.0]) + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=False, + pair_normalise=False, + ) + denominators = np.array([2.0, 2.0, 4.0]) + check = unnormalised / denominators[:, np.newaxis, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + windows=windows, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(implm.sum(axis=-1), np.ones((3, 1))) + np.testing.assert_allclose(proto, check) + np.testing.assert_allclose(proto.sum(axis=-1), np.ones((3, 1))) + + def test_root_sample(self): + """ + Flag node 9 (root, time 2, population B) as a sample. 
It cannot + coalesce with its descendents, so the pair counts per tree are: + + tree 1: AA=1, BB=3, AB=6 + tree 2: AA=1, BB=0, AB=2 + tree 3: AA=0, BB=3, AB=0 + tree 4: AA=0, BB=0, AB=0 + """ + tab = self.example_ts().dump_tables() + tab.nodes[9] = tab.nodes[9].replace(flags=tskit.NODE_IS_SAMPLE, population=1) + ts = tab.tree_sequence() + ss = [[0, 1], [2, 3, 4, 9]] + idx = [(0, 0), (0, 1), (1, 1)] + unnormalised = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=False, + pair_normalise=False, + ) + denominators = np.array([2.0, 8.0, 6.0]) # AA, AB, BB + check = unnormalised / denominators[:, np.newaxis] + implm = ts.pair_coalescence_counts( + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=ss, + indexes=idx, + span_normalise=True, + pair_normalise=True, + ) + np.testing.assert_allclose(implm, check) + np.testing.assert_allclose(proto, check) + + class TestCoalescingPairsSimulated: """ Test against a naive implementation on simulated data. 
@@ -1158,6 +1875,34 @@ def example_ts(self): assert ts.num_trees > 1 return ts + @tests.cached_example + def messy_ts(self): + """ + with only some samples missing over intervals, + and many internal nodes flagged as samples + """ + rng = np.random.default_rng(seed=2048) + ts = self.example_ts() + orig_num_samples = ts.num_samples + orig_segsites = ts.segregating_sites(mode="branch") + tables = ts.dump_tables() + for s in ts.samples(): + a, b = ts.sequence_length * rng.uniform([0.0, 0.6], [0.4, 1.0]) + _remove_partial_ancestry(tables, s, a, b) + nodes_flags = tables.nodes.flags + nodes_population = tables.nodes.population + internal = rng.integers(ts.num_samples, ts.num_nodes, size=30) + population = rng.integers(0, 2, size=30) + nodes_flags[internal] = tskit.NODE_IS_SAMPLE + nodes_population[internal] = population + tables.nodes.flags = nodes_flags + tables.nodes.population = nodes_population + tables.sort() + ts = tables.tree_sequence() + assert ts.num_samples > orig_num_samples + assert ts.segregating_sites(mode="branch") < orig_segsites + return ts + @staticmethod def _check_total_pairs(ts, windows): samples = list(ts.samples()) @@ -1168,14 +1913,13 @@ def _check_total_pairs(ts, windows): tsw = ts.keep_intervals(np.array([[a, b]]), simplify=False) check[w] = naive_pair_coalescence_counts(tsw, samples, samples) / 2 np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, windows=windows, span_normalise=False) np.testing.assert_allclose(proto, check) @staticmethod def _check_subset_pairs(ts, windows): - ss0 = np.flatnonzero(ts.nodes_population == 0) - ss1 = np.flatnonzero(ts.nodes_population == 1) + ss0 = list(ts.samples(population=0)) + ss1 = list(ts.samples(population=1)) idx = [(0, 1), (1, 1), (0, 0)] implm = ts.pair_coalescence_counts( sample_sets=[ss0, ss1], indexes=idx, windows=windows, span_normalise=False @@ -1188,7 +1932,6 @@ def _check_subset_pairs(ts, windows): check[w, 1] = 
naive_pair_coalescence_counts(tsw, ss1, ss1) / 2 check[w, 2] = naive_pair_coalescence_counts(tsw, ss0, ss0) / 2 np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[ss0, ss1], @@ -1289,7 +2032,6 @@ def test_span_normalise(self): implm = ts.pair_coalescence_counts(windows=windows, span_normalise=False) check = ts.pair_coalescence_counts(windows=windows) * window_size[:, np.newaxis] np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, windows=windows, span_normalise=False) np.testing.assert_allclose(proto, check) @@ -1308,7 +2050,6 @@ def test_span_normalise_with_missing_flanks(self): ) implm = ts.pair_coalescence_counts(windows=windows, span_normalise=True) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, windows=windows, span_normalise=True) np.testing.assert_allclose(proto, check) @@ -1336,6 +2077,45 @@ def test_span_normalise_with_missing_interior(self): ).flatten() np.testing.assert_array_almost_equal(proto, check) + def test_messy(self): + """ + test when only some samples are missing and there are + internal samples + """ + ts = self.messy_ts() + windows = np.linspace(0, ts.sequence_length, 5) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_normalised_sums_to_unity(self): + """ + test that fully normalised counts sum to one across all nodes, + even with missingness and internal samples + """ + ts = self.messy_ts() + sample_sets = [ + list(ts.samples(population=0)), + list(ts.samples(population=1)), + ] + windows = np.linspace(0, ts.sequence_length, 5) + implm = ts.pair_coalescence_counts( + sample_sets=sample_sets, + indexes=[(0, 0), (0, 1), (1, 1)], + windows=windows, + pair_normalise=True, + span_normalise=True, + ) + proto = proto_pair_coalescence_counts( + ts, + sample_sets=sample_sets, + indexes=[(0, 0), (0, 1), 
(1, 1)], + windows=windows, + pair_normalise=True, + span_normalise=True, + ) + np.testing.assert_allclose(implm.sum(axis=-1), 1.0) + np.testing.assert_allclose(proto.sum(axis=-1), 1.0) + def test_empty_windows(self): """ test that windows without nodes contain zeros @@ -1362,7 +2142,6 @@ def test_pair_normalise(self): check = ts.pair_coalescence_counts(windows=windows) * window_size[:, np.newaxis] check /= total_pairs np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, windows=windows, span_normalise=False, pair_normalise=True ) @@ -1401,7 +2180,6 @@ def test_time_windows(self): nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1 check = np.bincount(nodes_map, weights=check) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, span_normalise=False, time_windows=time_windows ) @@ -1430,7 +2208,6 @@ def test_time_windows_truncated(self): oob = np.logical_or(nodes_map < 0, nodes_map >= time_windows.size) check = np.bincount(nodes_map[~oob], weights=check[~oob]) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, span_normalise=False, time_windows=time_windows ) @@ -1454,7 +2231,6 @@ def test_time_windows_unique(self): nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1 check = np.bincount(nodes_map, weights=check) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, span_normalise=False, time_windows=time_windows ) @@ -1471,7 +2247,6 @@ def test_diversity(self): implm = ts.pair_coalescence_counts(windows=windows) implm = 2 * (implm @ ts.nodes_time) / implm.sum(axis=1) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts(ts, windows=windows) proto = 2 * (proto @ ts.nodes_time) / proto.sum(axis=1) 
np.testing.assert_allclose(proto, check) @@ -1488,7 +2263,6 @@ def test_divergence(self): implm = ts.pair_coalescence_counts(sample_sets=[ss0, ss1], windows=windows) implm = 2 * (implm @ ts.nodes_time) / implm.sum(axis=1) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_counts( ts, sample_sets=[ss0, ss1], windows=windows ) @@ -1645,7 +2419,6 @@ def test_quantiles(self): check = _numpy_weighted_quantile(ts.nodes_time, weights, quantiles) implm = ts.pair_coalescence_quantiles(quantiles) np.testing.assert_allclose(implm, check) - # TODO: remove with prototype proto = proto_pair_coalescence_quantiles(ts, quantiles=quantiles) np.testing.assert_allclose(proto, check) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index b0396d6caa..9d10470f3b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10630,7 +10630,10 @@ def pair_coalescence_counts( ): """ Calculate the number of coalescing sample pairs per node, summed over - trees and weighted by tree span. + trees and weighted by tree span. Precisely, if `c_i(x)` is the + number of pairs that coalesce in node `i` at position `x` (which may be + zero), then the output for the `i`th node is the integral of `c_i` over + the tree sequence (or the intervals defined by `windows`). The number of coalescing pairs may be calculated within or between the non-overlapping lists of samples contained in `sample_sets`. In the @@ -10648,6 +10651,21 @@ def pair_coalescence_counts( events within time intervals (if an array of breakpoints is supplied) rather than for individual nodes (the default). + The flags `span_normalise` and `pair_normalise` control the units of + the output. Let `p(x)` be the number of pairs that could potentially + coalesce at position `x` (omitting "isolated" samples). 
If both + `span_normalise` and `pair_normalise` are true, then the output is + divided by the integral of `p(x)` over the sequence, and will thus sum to + one over the "time" dimension (provided all non-isolated samples trace + back to a single common ancestor per tree). If only `span_normalise` is + set, then the output is divided by `(integral p(x) dx) / (num_samples * + (num_samples - 1) / 2)` (the average non-missing sequence length per + sample pair). Similarly, if only `pair_normalise` is set, the output + is divided by `(integral p(x) dx) / (nonmissing sequence)` (the + average number of sample pairs per base). The default is + `span_normalise` so that the units of the output are "number of sample + pairs". + The output array has dimension `(windows, indexes, nodes)` with dimensions dropped when the corresponding argument is set to None.