Coverage for barbet/data.py: 63.98%

161 statements  

coverage.py v7.9.1, created at 2025-08-12 04:23 +0000

import random
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Iterable

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as L
from hierarchicalsoftmax import TreeDict


RANKS = ["phylum", "class", "order", "family", "genus", "species"]


def read_memmap(path, count:int, dtype:str="float16") -> np.memmap:
    """Memory-map a file of `count` embeddings, inferring the embedding size from the file size."""
    file_size = os.path.getsize(path)
    dtype_size = np.dtype(dtype).itemsize
    num_elements = file_size // dtype_size
    embedding_size = num_elements // count
    shape = (count, embedding_size)
    return np.memmap(path, dtype=dtype, mode='r', shape=shape)

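# Usage sketch for read_memmap (hypothetical path and count, not part of this module):
# the embedding size falls out of the file size, e.g. a 1,024,000-byte float16 file
# holding 1000 embeddings gives 1,024,000 / 2 / 1000 = 512 values per embedding.
#
#     array = read_memmap("embeddings.dat", count=1000)
#     assert array.shape == (1000, 512)
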

def gene_id_from_accession(accession:str) -> str:
    """Return the gene ID, i.e. the portion of the accession after the final slash."""
    return accession.split("/")[-1]


def choose_k_from_n(lst, k:int) -> list[int]:
    """Choose `k` items from `lst`, repeating the whole list as many times as possible
    and filling the remainder by sampling without replacement."""
    n = len(lst)
    if n == 0:
        return []
    repetitions = k // n
    remainder = k % n
    return lst * repetitions + random.sample(lst, remainder)

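# Usage sketch for choose_k_from_n (illustrative values only): the whole list is
# repeated k // n times, then k % n extra items are drawn without replacement.
#
#     random.seed(0)
#     choose_k_from_n([1, 2, 3], 7)   # e.g. [1, 2, 3, 1, 2, 3, 2] (7 items, each appearing at least twice)
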

@dataclass(kw_only=True)
class BarbetStack:
    """The row indices into the embedding array for one stack of genes from a single genome."""
    genome:str
    array_indices:np.ndarray

    def __post_init__(self):
        assert self.array_indices.ndim == 1, "Stack indices must be a 1D array"


@dataclass(kw_only=True)
class BarbetPredictionDataset(Dataset):
    """Serves fixed-size stacks of gene-embedding rows, grouped by genome, for prediction.

    Each genome's rows are tiled into stacks of `stack_size` over `repeats` passes.
    """
    array:np.memmap|np.ndarray
    accessions: list[str]
    stack_size:int
    repeats:int = 2
    seed:int = 42
    stacks: list[BarbetStack] = field(init=False)
    genome_filter:set[str]|None = None

    def __post_init__(self):
        # Group array row indices by genome (the accession prefix before the final slash)
        genome_to_array_indices = defaultdict(set)
        for index, accession in enumerate(self.accessions):
            slash_position = accession.rfind("/")
            assert slash_position != -1
            genome = accession[:slash_position]
            if self.genome_filter and genome not in self.genome_filter:
                continue
            genome_to_array_indices[genome].add(index)

        # Build stacks
        random.seed(self.seed)
        self.stacks = []
        for genome, genome_array_indices in genome_to_array_indices.items():
            remainder = []
            for repeat_index in range(self.repeats + 1):
                if len(remainder) == 0 and repeat_index >= self.repeats:
                    break

                # Top up the remainder left over from the previous pass to a full stack
                genome_array_indices_set = set(genome_array_indices)
                available = genome_array_indices_set - set(remainder)
                needed = self.stack_size - len(remainder)
                available_list = list(available)

                if len(available_list) >= needed:
                    to_add = random.sample(available_list, needed)  # without replacement
                else:
                    to_add = random.choices(available_list, k=needed)  # with replacement
                to_add_set = set(to_add)
                assert not set(remainder) & to_add_set, "remainder and to_add should be disjoint"

                self.add_stack(genome, remainder + to_add)
                remainder = list(genome_array_indices_set - to_add_set)
                random.shuffle(remainder)

                # If we have already added each item the required number of times, then stop
                if repeat_index >= self.repeats:
                    break

            # Emit any leftover full stacks
            while len(remainder) >= self.stack_size:
                self.add_stack(genome, remainder[:self.stack_size])
                remainder = remainder[self.stack_size:]


    def add_stack(self, genome:str, indices:Iterable[int]) -> BarbetStack:
        """
        Add a new stack to the dataset.
        """
        indices = np.array(sorted(indices))
        stack = BarbetStack(genome=genome, array_indices=indices)
        self.stacks.append(stack)
        return stack

    def __len__(self):
        return len(self.stacks)

    def __getitem__(self, idx):
        stack = self.stacks[idx]
        array_indices = stack.array_indices

        assert len(array_indices) > 0, "Stack has no array indices"
        with torch.no_grad():
            data = np.asarray(self.array[array_indices, :]).copy()
            embeddings = torch.from_numpy(data).to(torch.float16)

        del data

        return embeddings

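# Usage sketch for BarbetPredictionDataset (hypothetical names; assumes accessions of
# the form "genome/gene" matching the rows of `array`):
#
#     dataset = BarbetPredictionDataset(array=array, accessions=accessions, stack_size=32)
#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     for embeddings in loader:   # embeddings: (1, stack_size, embedding_size) float16
#         ...
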

@dataclass(kw_only=True)
class BarbetTrainingDataset(Dataset):
    """Serves one stack of gene embeddings per accession, labelled with its TreeDict node ID."""
    accessions: list[str]
    treedict: TreeDict
    array:np.memmap|np.ndarray
    accession_to_array_index:dict[str,list[int]]|None = None
    stack_size:int = 0

    def __len__(self):
        return len(self.accessions)

    def __getitem__(self, idx):
        accession = self.accessions[idx]
        # Fall back to a single-row stack at `idx` when no index mapping is given
        array_indices = self.accession_to_array_index[accession] if self.accession_to_array_index else [idx]
        if self.stack_size:
            array_indices = choose_k_from_n(array_indices, self.stack_size)

        assert len(array_indices) > 0, f"Accession {accession} has no array indices"
        with torch.no_grad():
            data = np.array(self.array[array_indices, :], copy=False)
            embedding = torch.tensor(data, dtype=torch.float16)
        del data

        seq_detail = self.treedict[accession]
        node_id = int(seq_detail.node_id)
        del seq_detail

        return embedding, node_id

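# Note: with a fixed stack_size, each BarbetTrainingDataset item is a
# (stack_size, embedding_size) float16 tensor plus an integer node ID, so default
# DataLoader collation yields batches of shape (batch_size, stack_size, embedding_size)
# and (batch_size,). With stack_size=0, stacks may vary in length between accessions
# and would need a custom collate function.
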

@dataclass
class BarbetDataModule(L.LightningDataModule):
    treedict: TreeDict
    array:np.memmap|np.ndarray
    accession_to_array_index:dict[str,list[int]]
    max_items: int = 0
    batch_size: int = 16
    num_workers: int = 0
    validation_partition:int = 0
    test_partition:int = -1
    stack_size:int = 0
    train_all:bool = False

    def __init__(
        self,
        treedict: TreeDict,
        array:np.memmap|np.ndarray,
        accession_to_array_index:dict[str,list[int]],
        max_items: int = 0,
        batch_size: int = 16,
        num_workers:int|None = None,
        validation_partition:int = 0,
        test_partition:int = -1,
        stack_size:int = 0,
        train_all:bool = False,
    ):
        super().__init__()
        self.array = array
        self.accession_to_array_index = accession_to_array_index
        self.treedict = treedict
        self.max_items = max_items
        self.batch_size = batch_size
        self.validation_partition = validation_partition
        self.test_partition = test_partition
        # Default to one worker per CPU, capped at 8
        self.num_workers = min(os.cpu_count(), 8) if num_workers is None else num_workers
        self.stack_size = stack_size
        self.train_all = train_all

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        self.training = []
        self.validation = []

        for accession, details in self.treedict.items():
            partition = details.partition
            if partition == self.test_partition:
                continue

            dataset = self.validation if partition == self.validation_partition else self.training
            dataset.append(accession)

            if self.max_items and len(self.training) >= self.max_items and len(self.validation) > 0:
                break

        if self.train_all:
            self.training += self.validation

        self.train_dataset = self.create_dataset(self.training)
        self.val_dataset = self.create_dataset(self.validation)

    def create_dataset(self, accessions:list[str]) -> BarbetTrainingDataset:
        return BarbetTrainingDataset(
            accessions=accessions,
            treedict=self.treedict,
            array=self.array,
            accession_to_array_index=self.accession_to_array_index,
            stack_size=self.stack_size,
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
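
# Usage sketch for BarbetDataModule with Lightning (hypothetical `model`, `treedict`,
# `array`, and `accession_to_array_index`; a minimal sketch, not the project's CLI):
#
#     datamodule = BarbetDataModule(
#         treedict=treedict,
#         array=array,
#         accession_to_array_index=accession_to_array_index,
#         batch_size=16,
#         stack_size=32,
#     )
#     trainer = L.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=datamodule)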