Coverage for barbet/modules.py: 11.76%
102 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-12 04:23 +0000
1import gc
2from torchapp.modules import GeneralLightningModule
3# import pandas as pd
4import polars as pl
5import torch
6from collections import defaultdict
7from hierarchicalsoftmax.inference import (
8 greedy_lineage_probabilities,
9 node_probabilities,
10 greedy_predictions,
11)
12from barbet.data import RANKS
class BarbetLightningModule(GeneralLightningModule):
    """Aggregates per-sequence logits into per-genome taxonomic predictions.

    During the predict loop, raw model outputs for individual sequences are
    summed per genome (keyed by name), then averaged and decoded against the
    hierarchical-softmax classification tree into per-rank predictions and
    probabilities, stored in ``self.results_df`` (a Polars DataFrame).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def setup_prediction(self, barbet, names:list[str]|str, threshold:float=0.0, save_probabilities:bool=False):
        """Initialise accumulation buffers before the predict loop runs.

        Args:
            barbet: App object providing ``node_to_str`` for naming tree nodes.
            names: Genome name for each prediction row, or a single name shared
                by every row. When a list, rows belonging to the same genome are
                assumed contiguous (see ``on_predict_batch_end``).
            threshold: Minimum probability required to keep descending the tree
                during greedy decoding.
            save_probabilities: If True, keep the per-node probabilities as
                extra columns in the results DataFrame.
        """
        self.names = names
        self.classification_tree = self.hparams.classification_tree
        self.counter = 0  # rows of `names` consumed so far across batches

        unique_names = list(set(names)) if isinstance(names, list) else [names]
        genome_count = len(unique_names)
        self.name_to_index = {name: i for i, name in enumerate(unique_names)}

        # Running per-genome logit sums; float16 keeps memory bounded for
        # large classification trees, at some cost in accumulation precision.
        self.logits = torch.zeros(
            (genome_count, self.classification_tree.layer_size),
            dtype=torch.float16,
        )
        # Number of sequence rows accumulated per genome (used for averaging).
        self.counts = torch.zeros(
            (genome_count,),
            dtype=torch.int32,
        )
        self.category_names = [
            barbet.node_to_str(node)
            for node in self.classification_tree.node_list_softmax
            if not node.is_root
        ]
        self.barbet = barbet
        self.threshold = threshold
        self.save_probabilities = save_probabilities

    def on_predict_batch_end(self, results, batch, batch_idx, dataloader_idx=0):
        """Accumulate this batch's logits into the per-genome buffers.

        ``results`` is expected to be a (batch, layer_size) tensor of raw
        activations. When ``self.names`` is a list, consecutive rows with the
        same name are flushed into that genome's slot as a single sum, which
        assumes rows for a genome are contiguous in the dataset order.
        """
        batch_size = len(results)
        if isinstance(self.names, str):
            # Single genome: the entire batch belongs to it.
            genome_index = self.name_to_index[self.names]
            self.counts[genome_index] += batch_size
            self.logits[genome_index, :] += results.sum(dim=0).half().cpu()
        else:
            # Walk the batch, flushing each contiguous run of identical names.
            prev_name = self.names[self.counter]
            start_i = 0
            for end_i in range(batch_size):
                current_name = self.names[self.counter + end_i]
                if current_name != prev_name:
                    genome_index = self.name_to_index[prev_name]
                    self.counts[genome_index] += (end_i - start_i)
                    self.logits[genome_index, :] += results[start_i:end_i].sum(dim=0).half().cpu()
                    start_i = end_i
                    prev_name = current_name

            # Flush the final run (always non-empty for a non-empty batch).
            assert start_i < batch_size, "Start index should be less than batch size"
            genome_index = self.name_to_index[prev_name]
            self.logits[genome_index, :] += results[start_i:].sum(dim=0).half().cpu()
            self.counts[genome_index] += (batch_size - start_i)
        self.counter += batch_size

    def on_predict_epoch_end(self):
        """Decode the accumulated logits into ``self.results_df``.

        Averages the per-genome logit sums, then either computes full per-node
        probabilities plus greedy predictions (``save_probabilities=True``) or
        goes straight to greedy per-rank lineage probabilities. In both cases
        the resulting Polars DataFrame has a ``name`` column followed by
        ``{rank}_prediction`` / ``{rank}_probability`` pairs for each rank.
        """
        print("Consolidating results per genome...")
        names = list(self.name_to_index.keys())
        # Average the summed logits. Every genome in `name_to_index` received
        # at least one row during prediction, so counts are non-zero here.
        self.logits /= self.counts.unsqueeze(1)
        del self.counts
        gc.collect()

        # Output column order and the per-column value lists to fill below.
        output_columns = ['name']
        new_cols = {'name': names}
        for rank in RANKS:
            pred_col = f"{rank}_prediction"
            prob_col = f"{rank}_probability"
            output_columns += [pred_col, prob_col]
            new_cols[pred_col] = []
            new_cols[prob_col] = []

        if self.save_probabilities:
            print("Converting to probabilities...")
            probabilities = node_probabilities(
                self.logits,
                root=self.classification_tree,
                progress_bar=True,
            )
            # Free the (large) logit buffer as soon as it is no longer needed.
            del self.logits
            gc.collect()

            print("Saving in dataframe...")
            self.results_df = pl.DataFrame(
                data=probabilities,
                schema=self.category_names
            ).with_columns([
                pl.Series("name", names, dtype=pl.Utf8)
            ]).select(["name", *self.category_names])

            # Greedy predictions can work from raw activations or probabilities.
            print("Getting greedy predictions...")
            predictions = greedy_predictions(
                probabilities,
                root=self.classification_tree,
                threshold=self.threshold,
                progress_bar=True,
            )
            del probabilities
            gc.collect()

            num_rows = self.results_df.height
            for i in range(num_rows):
                prediction_node = predictions[i]
                # Full lineage from just below the root down to the prediction.
                lineage = prediction_node.ancestors[1:] + (prediction_node,)
                probability = 1.0
                for rank, lineage_node in zip(RANKS, lineage):
                    node_name = self.barbet.node_to_str(lineage_node)
                    pred_col = f"{rank}_prediction"
                    prob_col = f"{rank}_probability"
                    new_cols[pred_col].append(node_name)
                    # If a rank's node has no probability column, carry the
                    # last known probability down the lineage.
                    if node_name in self.results_df.columns:
                        probability = self.results_df[node_name][i]
                    new_cols[prob_col].append(probability)

            # Attach the prediction/probability columns to the DataFrame.
            self.results_df = self.results_df.with_columns(
                [pl.Series(name, values) for name, values in new_cols.items()]
            )
            output_columns += self.category_names
            self.results_df = self.results_df[output_columns]
        else:
            print("Finding greedy predictions...")
            results = greedy_lineage_probabilities(
                self.logits,
                root=self.classification_tree,
                threshold=self.threshold,
                progress_bar=True,
            )
            del self.logits
            gc.collect()

            for row in results:
                assert len(row) == len(RANKS), f"Row length {len(row)} does not match number of ranks {len(RANKS)}"
                for rank_index, (node, probability) in enumerate(row):
                    rank = RANKS[rank_index]
                    node_name = self.barbet.node_to_str(node)
                    pred_col = f"{rank}_prediction"
                    prob_col = f"{rank}_probability"
                    new_cols[pred_col].append(node_name)
                    new_cols[prob_col].append(probability)

            # Create the DataFrame in the final column order.
            self.results_df = pl.DataFrame(
                data=new_cols,
                schema=output_columns
            ).with_columns([
                pl.col("name").cast(pl.Utf8)
            ]).select(output_columns)