Coverage for barbet/modules.py: 11.76%

102 statements  

coverage.py v7.9.1, created at 2025-08-12 04:23 +0000

import gc

from torchapp.modules import GeneralLightningModule
import polars as pl
import torch
from hierarchicalsoftmax.inference import (
    greedy_lineage_probabilities,
    node_probabilities,
    greedy_predictions,
)
from barbet.data import RANKS


class BarbetLightningModule(GeneralLightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def setup_prediction(self, barbet, names: list[str] | str, threshold: float = 0.0, save_probabilities: bool = False):
        self.names = names
        self.classification_tree = self.hparams.classification_tree
        self.counter = 0
        unique_names = list(set(names)) if isinstance(names, list) else [names]  # note: set() does not preserve input order
        genome_count = len(unique_names)
        self.name_to_index = {name: i for i, name in enumerate(unique_names)}
        self.logits = torch.zeros(
            (genome_count, self.classification_tree.layer_size),
            dtype=torch.float16,
        )
        self.counts = torch.zeros(
            (genome_count,),
            dtype=torch.int32,
        )
        self.category_names = [
            barbet.node_to_str(node) for node in self.classification_tree.node_list_softmax if not node.is_root
        ]
        self.barbet = barbet
        self.threshold = threshold
        self.save_probabilities = save_probabilities
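
    # Note (added commentary, not original code): `names` is expected to hold
    # one genome name per item the prediction dataloader will yield, in
    # dataloader order, with all items for a genome contiguous, for example:
    #   module.setup_prediction(barbet, names=["genome_a", "genome_a", "genome_b"])
    # on_predict_batch_end below relies on this ordering to map batch rows back
    # to genomes. The exact dataloader wiring is an assumption here.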

    def on_predict_batch_end(self, results, batch, batch_idx, dataloader_idx=0):
        batch_size = len(results)
        if isinstance(self.names, str):
            # Single genome: every row in the batch belongs to it
            genome_index = self.name_to_index[self.names]
            self.counts[genome_index] += batch_size
            self.logits[genome_index, :] += results.sum(dim=0).half().cpu()
        else:
            # Rows for the same genome are contiguous, so scan the batch for
            # name boundaries and accumulate one contiguous chunk per genome
            prev_name = self.names[self.counter]
            start_i = 0
            for end_i in range(batch_size):
                current_name = self.names[self.counter + end_i]
                if current_name != prev_name:
                    genome_index = self.name_to_index[prev_name]
                    self.counts[genome_index] += (end_i - start_i)
                    self.logits[genome_index, :] += results[start_i:end_i].sum(dim=0).half().cpu()
                    start_i = end_i
                    prev_name = current_name

            # Handle the last chunk
            assert start_i < batch_size, "Start index should be less than batch size"
            genome_index = self.name_to_index[prev_name]
            self.logits[genome_index, :] += results[start_i:].sum(dim=0).half().cpu()
            self.counts[genome_index] += (batch_size - start_i)
        self.counter += batch_size
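
    # Worked example of the chunking above (added commentary, toy data): with
    # self.names == ["g1", "g1", "g2"], counter == 0, and a batch of three
    # rows, the loop finds the g1 -> g2 boundary at end_i == 2, adds rows 0-1
    # to g1's logits with its count increasing by 2, and the post-loop block
    # adds row 2 to g2 with its count increasing by 1. on_predict_epoch_end
    # then divides each genome's summed logits by its count to get mean logits.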

    def on_predict_epoch_end(self):
        print("Consolidating results per genome...")
        names = list(self.name_to_index.keys())
        self.logits /= self.counts.unsqueeze(1)  # normalize summed logits by counts to get per-genome means
        del self.counts
        gc.collect()

        # Prepare column names and initialize empty lists
        output_columns = ['name']
        new_cols = {}
        new_cols['name'] = names

        for rank in RANKS:
            pred_col = f"{rank}_prediction"
            prob_col = f"{rank}_probability"
            output_columns += [pred_col, prob_col]
            new_cols[pred_col] = []
            new_cols[prob_col] = []

        # Convert to probabilities
        if self.save_probabilities:
            print("Converting to probabilities...")
            probabilities = node_probabilities(
                self.logits,
                root=self.classification_tree,
                progress_bar=True,
            )

            del self.logits
            gc.collect()

            print("Saving in dataframe...")
            self.results_df = pl.DataFrame(
                data=probabilities,
                schema=self.category_names,
            ).with_columns(
                pl.Series("name", names, dtype=pl.Utf8)
            ).select(["name", *self.category_names])

            # Get greedy predictions, which can use the raw activations or the softmax probabilities
            print("Getting greedy predictions...")
            predictions = greedy_predictions(
                probabilities,
                root=self.classification_tree,
                threshold=self.threshold,
                progress_bar=True,
            )

            del probabilities
            gc.collect()

            # Walk each genome's predicted lineage and record per-rank names and probabilities
            num_rows = self.results_df.height

            for i in range(num_rows):
                prediction_node = predictions[i]
                lineage = prediction_node.ancestors[1:] + (prediction_node,)
                probability = 1.0

                for rank, lineage_node in zip(RANKS, lineage):
                    node_name = self.barbet.node_to_str(lineage_node)
                    pred_col = f"{rank}_prediction"
                    prob_col = f"{rank}_probability"

                    new_cols[pred_col].append(node_name)

                    if node_name in self.results_df.columns:
                        probability = self.results_df[node_name][i]

                    new_cols[prob_col].append(probability)

            # Add the new columns to the Polars DataFrame
            self.results_df = self.results_df.with_columns(
                [pl.Series(name, values) for name, values in new_cols.items()]
            )
            output_columns += self.category_names
            self.results_df = self.results_df[output_columns]
        else:
            print("Finding greedy predictions...")
            results = greedy_lineage_probabilities(
                self.logits,
                root=self.classification_tree,
                threshold=self.threshold,
                progress_bar=True,
            )

            del self.logits
            gc.collect()

            for row in results:
                assert len(row) == len(RANKS), f"Row length {len(row)} does not match number of ranks {len(RANKS)}"
                for rank_index, (node, probability) in enumerate(row):
                    rank = RANKS[rank_index]
                    node_name = self.barbet.node_to_str(node)
                    pred_col = f"{rank}_prediction"
                    prob_col = f"{rank}_probability"

                    new_cols[pred_col].append(node_name)
                    new_cols[prob_col].append(probability)

            # Create the DataFrame
            self.results_df = pl.DataFrame(
                data=new_cols,
                schema=output_columns,
            ).with_columns(
                pl.col("name").cast(pl.Utf8)
            ).select(output_columns)

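
# ---------------------------------------------------------------------------
# Standalone sketch (illustrative addition, not part of the original module):
# it reproduces the per-genome aggregation that BarbetLightningModule performs,
# using toy tensors so the bookkeeping can be checked without barbet's data.
# The names, shapes, and values below are assumptions chosen for the example.
if __name__ == "__main__":
    toy_names = ["genome_a", "genome_a", "genome_b"]  # one genome name per input row
    toy_logits = torch.ones((3, 4))                   # fake per-row logits: 3 rows, 4 tree nodes

    # Map each distinct genome name to a row index (dict.fromkeys keeps order)
    name_to_index = {name: i for i, name in enumerate(dict.fromkeys(toy_names))}
    sums = torch.zeros((len(name_to_index), 4), dtype=torch.float16)
    counts = torch.zeros((len(name_to_index),), dtype=torch.int32)

    # Accumulate per-genome row sums and counts, as on_predict_batch_end does
    for name, row in zip(toy_names, toy_logits):
        index = name_to_index[name]
        sums[index] += row.half()
        counts[index] += 1

    # Mean logits per genome, matching on_predict_epoch_end's normalization
    means = sums / counts.unsqueeze(1)
    print(means)  # both genomes average to rows of ones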