Coverage for barbet/embedding.py: 31.99%

322 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-12 04:23 +0000

1import gzip 

2import os 

3from pathlib import Path 

4from abc import ABC, abstractmethod 

5from Bio import SeqIO 

6import random 

7import numpy as np 

8from rich.progress import track 

9from hierarchicalsoftmax import SoftmaxNode 

10from hierarchicalsoftmax import TreeDict 

11import tarfile 

12import torch 

13from io import StringIO 

14from torchapp.cli import CLIApp, tool, method 

15import typer 

16from dataclasses import dataclass 

17 

18from .data import read_memmap, RANKS 

19 

20 

def _open(path, mode='rt', **kwargs):
    """
    Open a file, transparently using gzip for paths ending in '.gz'.

    Args:
        path (str or Path): The path to the file.
        mode (str): The mode to open the file with (default 'rt' for reading text).
        **kwargs: Additional arguments passed to open or gzip.open.

    Returns:
        A file object.
    """
    path = Path(path)
    # Choose the opener based on the suffix; both share the same call signature.
    opener = gzip.open if path.suffix == '.gz' else open
    return opener(path, mode, **kwargs)

37 

38 

def set_validation_rank_to_treedict(
    treedict:TreeDict,
    validation_rank:str="species",
    partitions:int=5,
) -> TreeDict:
    """
    Re-assigns the cross-validation partition of every key in the treedict so that
    all keys sharing the same ancestor at `validation_rank` land in the same partition.

    Partitions are drawn uniformly at random in [0, partitions). Mutates and
    returns the same treedict.
    """
    # Resolve the taxonomic rank used to group keys into partitions
    rank = validation_rank.lower()
    assert rank in RANKS
    rank_index = RANKS.index(rank)

    assigned_partition = {}
    for key in treedict:
        # All keys under the same ancestor node at this rank share one partition
        ancestor = treedict.node(key).ancestors[rank_index]
        if ancestor not in assigned_partition:
            assigned_partition[ancestor] = random.randint(0, partitions - 1)
        treedict[key].partition = assigned_partition[ancestor]

    return treedict

60 

61 

def get_key(accession:str, gene:str) -> str:
    """ Returns the standard format of a key: '<accession>/<gene>'. """
    return "/".join((accession, gene))

66 

67 

def get_node(lineage:str, lineage_to_node:dict[str,SoftmaxNode]) -> SoftmaxNode:
    """
    Returns the node for a semi-colon delimited lineage string, creating it
    (and, recursively, any missing ancestors) on first use.

    `lineage_to_node` acts as a memo cache and is updated in place.
    """
    if lineage in lineage_to_node:
        return lineage_to_node[lineage]

    # The root lineage must already be registered in the cache, so every lineage
    # created here has a parent to attach to.
    assert ";" in lineage, f"Semi-colon ';' not found in lineage '{lineage}'"
    parent_lineage, _, name = lineage.rpartition(";")
    parent_node = get_node(parent_lineage, lineage_to_node)
    node = SoftmaxNode(name=name, parent=parent_node)
    lineage_to_node[lineage] = node
    return node

80 

81 

def generate_overlapping_intervals(total: int, interval_size: int, min_overlap: int, check:bool=True, variable_size:bool=False) -> list[tuple[int, int]]:
    """
    Creates a list of overlapping intervals within a specified range, adjusting either the
    interval size or the overlap so that the overlap is approximately the same across all intervals.

    Args:
        total (int): The total range within which intervals are to be created.
        interval_size (int): The maximum size of each interval.
        min_overlap (int): The minimum number of units by which consecutive intervals overlap.
        check (bool): If True, checks are performed to ensure that the intervals meet the specified conditions.
        variable_size (bool): If True, the overlap is fixed at `min_overlap` and the interval size
            is adjusted to cover the range evenly; if False (default), the interval size is fixed
            and the overlap is increased above `min_overlap` as needed.

    Returns:
        list[tuple[int, int]]: A list of tuples where each tuple represents the start (inclusive)
        and end (exclusive) of an interval.

    Example:
        >>> generate_overlapping_intervals(20, 5, 2)
        [(0, 5), (3, 8), (6, 11), (9, 14), (12, 17), (15, 20)]
    """
    intervals = []
    start = 0

    if total == 0:
        return intervals

    max_interval_size = interval_size
    assert interval_size
    assert min_overlap is not None
    assert interval_size > min_overlap, f"Max interval size of {interval_size} must be greater than min overlap of {min_overlap}"

    # Calculate the number of intervals needed to cover the range
    num_intervals, remainder = divmod(total - min_overlap, interval_size - min_overlap)
    if remainder > 0:
        num_intervals += 1

    # Calculate the exact interval size (or overlap) to ensure consistent coverage
    overlap = min_overlap
    if variable_size:
        if num_intervals > 1:
            interval_size, remainder = divmod(total + (num_intervals - 1) * overlap, num_intervals)
            if remainder > 0:
                interval_size += 1
    else:
        # If the size is fixed, then vary the overlap to keep it even
        if num_intervals > 1:
            overlap, remainder = divmod( num_intervals * interval_size - total, num_intervals - 1)
            if overlap < min_overlap:
                overlap = min_overlap

    while True:
        end = start + interval_size
        if end > total:
            # Clamp the final interval to the end of the range, sliding its start
            # back so it keeps the full interval size
            end = total
            start = max(end - interval_size,0)
        intervals.append((start, end))
        start += interval_size - overlap
        if end >= total:
            break

    if check:
        assert intervals[0][0] == 0
        assert intervals[-1][1] == total
        assert len(intervals) == num_intervals, f"Expected {num_intervals} intervals, got {len(intervals)}"

        assert interval_size <= max_interval_size, f"Interval size of {interval_size} exceeds max interval size of {max_interval_size}"
        for interval in intervals:
            assert interval[1] - interval[0] == interval_size, f"Interval size of {interval[1] - interval[0]} is not the expected size {interval_size}"

        for i in range(1, len(intervals)):
            overlap = intervals[i - 1][1] - intervals[i][0]
            assert overlap >= min_overlap, f"Min overlap condition of {min_overlap} not met for intervals {intervals[i - 1]} and {intervals[i]} (overlap {overlap})"

    return intervals

155 

156 

@dataclass
class Embedding(CLIApp, ABC):
    """ A class for embedding protein sequences. """
    max_length:int|None=None   # maximum sequence length the embedder accepts; longer sequences are chunked
    overlap:int=64             # minimum overlap (in residues) between consecutive chunks

    def __post_init__(self):
        super().__init__()

    @abstractmethod
    def embed(self, seq:str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding vector. """
        raise NotImplementedError

    def reduce(self, tensor:torch.Tensor) -> torch.Tensor:
        """ Reduces a per-residue embedding of shape (length, dim) to a single (dim,) vector by averaging. """
        if tensor.ndim == 2:
            tensor = tensor.mean(dim=0)
        assert tensor.ndim == 1
        return tensor

    def __call__(self, seq:str) -> torch.Tensor:
        """
        Takes a protein sequence as a string and returns an embedding vector.

        Sequences longer than `max_length` are split into overlapping chunks which are
        embedded separately and blended together with weights that taper towards the
        edges of each chunk, so the overlapping regions are a weighted average.
        """
        if not self.max_length or len(seq) <= self.max_length:
            tensor = self.embed(seq)
            return self.reduce(tensor)

        epsilon = 0.1  # keeps boundary residues at a small non-zero weight
        intervals = generate_overlapping_intervals(len(seq), self.max_length, self.overlap)
        weights = torch.zeros( (len(seq),), device="cpu" )
        tensor = None
        for start,end in intervals:
            result = self.embed(seq[start:end]).cpu()

            # assumes embed() returns one row per residue — (end-start, dim)
            assert result.shape[0] == end-start
            embedding_size = result.shape[1]

            if tensor is None:
                tensor = torch.zeros( (len(seq), embedding_size ), device="cpu")

            assert tensor.shape[-1] == embedding_size

            # BUGFIX: previously `torch.arange(end-start)` (chunk-relative indices) was
            # combined with the absolute `start`/`end`, producing negative distances and
            # therefore negative weights for every chunk after the first. Using absolute
            # positions makes both distance terms non-negative.
            interval_indexes = torch.arange(start, end)
            distance_from_ends = torch.min( interval_indexes-start, end-interval_indexes-1 )

            # Weight ramps up linearly from each chunk edge and saturates at `overlap`
            weight = epsilon + torch.minimum(distance_from_ends, torch.tensor(self.overlap))

            tensor[start:end] += result * weight.unsqueeze(1)
            weights[start:end] += weight

        # Normalise by the accumulated weight so overlaps average rather than sum
        tensor = tensor/weights.unsqueeze(1)

        return self.reduce(tensor)

    @method
    def setup(self, **kwargs):
        """ Hook for subclasses to load models/tokenizers etc. before embedding. """
        pass

    def build_treedict(self, taxonomy:Path) -> tuple[TreeDict,dict[str,SoftmaxNode]]:
        """
        Builds a hierarchical classification tree from a tab-separated taxonomy file.

        Each line holds an accession and a semi-colon delimited lineage. The first
        lineage component of the first line becomes the tree root.

        Returns:
            A tuple of the (empty) TreeDict built on the tree and a dict mapping
            each accession to its leaf node.
        """
        # Create root of tree
        lineage_to_node = {}
        root = None

        # Fill out tree with taxonomy
        accession_to_node = {}
        with _open(taxonomy) as f:
            for line in f:
                # BUGFIX: strip the trailing newline so the final taxon name of the
                # lineage is not polluted with '\n'
                accession, lineage = line.rstrip("\n").split("\t")

                if not root:
                    root_name = lineage.split(";")[0]
                    root = SoftmaxNode(root_name)
                    lineage_to_node[root_name] = root

                node = get_node(lineage, lineage_to_node)
                accession_to_node[accession] = node

        treedict = TreeDict(classification_tree=root)
        return treedict, accession_to_node

    @tool("setup")
    def test_lengths(
        self,
        end:int=5_000,
        start:int=1000,
        retries:int=5,
        **kwargs,
    ):
        """ Embeds random sequences of increasing length and reports the first length that fails. """
        def random_amino_acid_sequence(k):
            amino_acids = "ACDEFGHIKLMNPQRSTVWY" # standard 20 amino acids
            return ''.join(random.choice(amino_acids) for _ in range(k))

        self.max_length = None
        self.setup(**kwargs)
        for ii in track(range(start,end)):
            for _ in range(retries):
                seq = random_amino_acid_sequence(ii)
                try:
                    self(seq)
                except Exception as err:
                    # Stop at the first failing length
                    print(f"{ii}: {err}")
                    return

    @tool("setup")
    def build_gene_array(
        self,
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        family_index:int=typer.Option(default=..., help="The index for the gene family to use. E.g. if there are 120 gene families then this should be a number from 0 to 119."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        flush_every:int=typer.Option(default=5_000, help="An interval to flush the memmap array as it is generated."),
        max_length:int|None=None,
        **kwargs,
    ):
        """
        Embeds every sequence of a single marker-gene family and stores the vectors
        in a float16 memmap array, alongside a listing of keys and an error log.
        """
        self.max_length = max_length
        self.setup(**kwargs)

        assert marker_genes is not None
        assert family_index is not None
        assert output_dir is not None

        dtype = 'float16'

        memmap_wip_array = None
        output_dir.mkdir(parents=True, exist_ok=True)
        memmap_wip_path = output_dir / f"{family_index}-wip.npy"
        error = output_dir / f"{family_index}-errors.txt"
        accessions_wip = output_dir / f"{family_index}-accessions-wip.txt"

        accessions = []

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar, open(error, "w") as error_file, open(accessions_wip, "w") as accessions_wip_file:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            # The marker id is whatever remains of the file name after the common prefix
            prefix_length = len(os.path.commonprefix([Path(member.name).with_suffix("").name for member in members]))

            member = members[family_index]
            print(f"Processing file {family_index} in {marker_genes}")

            f = tar.extractfile(member)
            marker_id = Path(member.name).with_suffix("").name[prefix_length:]

            fasta_io = StringIO(f.read().decode('ascii'))

            # First pass counts the records so the work-in-progress memmap can be sized up front
            total = sum(1 for _ in SeqIO.parse(fasta_io, "fasta"))
            fasta_io.seek(0)
            print(marker_id, total)

            for record in track(SeqIO.parse(fasta_io, "fasta"), total=total):
                species_accession = record.id

                key = get_key(species_accession, marker_id)

                # Remove alignment gaps and stop characters before embedding
                seq = str(record.seq).replace("-","").replace("*","")
                try:
                    vector = self(seq)
                except Exception as err:
                    print(f"{key} ({len(seq)}): {err}", file=error_file)
                    print(f"{key} ({len(seq)}): {err}")
                    continue

                if vector is None:
                    print(f"{key} ({len(seq)}): Embedding is None", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding is None")
                    continue

                if torch.isnan(vector).any():
                    print(f"{key} ({len(seq)}): Embedding contains NaN", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding contains NaN")
                    continue

                if memmap_wip_array is None:
                    size = len(vector)
                    shape = (total,size)
                    memmap_wip_array = np.memmap(memmap_wip_path, dtype=dtype, mode='w+', shape=shape)

                index = len(accessions)
                memmap_wip_array[index,:] = vector.cpu().half().numpy()
                if index % flush_every == 0:
                    memmap_wip_array.flush()

                accessions.append(key)
                print(key, file=accessions_wip_file)

        # ROBUSTNESS: previously this crashed with an AttributeError/NameError if no
        # sequence embedded successfully
        if memmap_wip_array is None:
            raise RuntimeError(f"No sequences could be embedded for family {family_index}; see {error}")

        memmap_wip_array.flush()

        accessions_path = output_dir / f"{family_index}.txt"
        with open(accessions_path, "w") as f:
            for accession in accessions:
                print(accession, file=f)

        # Save final memmap array now that we now the final size
        memmap_path = output_dir / f"{family_index}.npy"
        shape = (len(accessions),size)
        print(f"Writing final memmap array of shape {shape}: {memmap_path}")
        memmap_array = np.memmap(memmap_path, dtype=dtype, mode='w+', shape=shape)
        memmap_array[:len(accessions),:] = memmap_wip_array[:len(accessions),:]
        memmap_array.flush()

        # Clean up
        # NOTE(review): touches numpy's private `_mmap` to force the file handle closed
        # so the WIP file can be unlinked — confirm against the numpy version in use
        memmap_array._mmap.close()
        memmap_array._mmap = None
        memmap_array = None
        memmap_wip_path.unlink()
        accessions_wip.unlink()

    @tool
    def set_validation_rank(
        self,
        treedict:Path=typer.Option(default=..., help="The path to the treedict file."),
        output:Path=typer.Option(default=..., help="The path to save the adapted treedict file."),
        validation_rank:str=typer.Option(default="species", help="The rank to hold out for cross-validation."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
    ) -> TreeDict:
        """ Loads a TreeDict, re-partitions it at `validation_rank` and saves it to `output`. """
        treedict = TreeDict.load(treedict)
        set_validation_rank_to_treedict(treedict, validation_rank=validation_rank, partitions=partitions)
        treedict.save(output)
        return treedict

    @tool
    def preprocess(
        self,
        taxonomy:Path=typer.Option(default=..., help="The path to the TSV taxonomy file (e.g. bac120_taxonomy_r220.tsv)."),
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
        seed:int=typer.Option(default=42, help="The random seed."),
        treedict_only:bool=typer.Option(default=False, help="Only output TreeDict file and then exit before concatenating memmap array"),
    ):
        """
        Combines the per-family outputs of `build_gene_array` in `output_dir` into a
        single TreeDict, a concatenated memmap array and a listing of keys.
        """
        treedict, accession_to_node = self.build_treedict(taxonomy)

        dtype = 'float16'

        random.seed(seed)

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            family_count = len(members)
            print(f"{family_count} gene families found.")

        # Read and collect accessions
        print(f"Building treedict")
        keys = []
        counts = []
        node_to_partition_dict = dict()
        for family_index in track(range(family_count)):
            keys_path = output_dir / f"{family_index}.txt"

            # Families with no output still need a zero count to keep indexes aligned
            if not keys_path.exists():
                counts.append(0)
                continue

            with open(keys_path) as f:
                family_index_keys = [line.strip() for line in f]
            keys += family_index_keys
            counts.append(len(family_index_keys))

            for key in family_index_keys:
                genome_accession = key.split("/")[0]
                node = accession_to_node[genome_accession]
                # A genome keeps the same partition across all of its genes
                partition = node_to_partition_dict.setdefault(node, random.randint(0, partitions - 1))

                # Add to treedict
                treedict.add(key, node, partition)

        assert len(counts) == family_count

        # Save treedict
        treedict_path = output_dir / f"{output_dir.name}.td"
        print(f"Saving TreeDict to {treedict_path}")
        treedict.save(treedict_path)

        if treedict_only:
            return

        # Concatenate numpy memmap arrays
        memmap_array = None
        memmap_array_path = output_dir / f"{output_dir.name}.npy"
        print(f"Saving memmap to {memmap_array_path}")
        current_index = 0
        # FIX: loop variable renamed from `family_count`, which shadowed the total
        # number of families computed above
        for family_index, family_key_count in track(enumerate(counts), total=len(counts)):
            my_memmap_path = output_dir / f"{family_index}.npy"

            if not my_memmap_path.exists():
                continue

            my_memmap = read_memmap(my_memmap_path, family_key_count)

            # Build memmap for output if it doesn't exist
            if memmap_array is None:
                size = my_memmap.shape[1]
                shape = (len(keys),size)
                memmap_array = np.memmap(memmap_array_path, dtype=dtype, mode='w+', shape=shape)

            # Copy memmap for gene family into output memmap
            memmap_array[current_index:current_index+family_key_count,:] = my_memmap[:,:]

            current_index += family_key_count

        assert len(keys) == current_index

        # ROBUSTNESS: memmap_array is None when no family arrays exist
        if memmap_array is not None:
            memmap_array.flush()

        # Save keys
        keys_path = output_dir / f"{output_dir.name}.txt"
        print(f"Saving keys to {keys_path}")
        with open(keys_path, "w") as f:
            for key in keys:
                print(key, file=f)

    @tool
    def prune_to_representatives(self, treedict:Path, representatives:Path, output:Path):
        """
        Prunes a TreeDict so it only keeps the keys for genes present in a tarball
        of representative genomes, saving the result to `output`.
        """
        # BUGFIX: added the missing `self` parameter for consistency with the other
        # @tool methods of this class, which are all bound methods
        print("Getting list of representatives from", representatives)
        keys_to_keep = []
        with tarfile.open(representatives, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]

            print(f"Processing {len(members)} files in {representatives}")

            for member in track(members):
                f = tar.extractfile(member)
                # NOTE(review): assumes the marker id is the final underscore-separated
                # component of the member name — confirm against the tarball layout
                marker_id = Path(member.name.split("_")[-1]).with_suffix("").name

                fasta_io = StringIO(f.read().decode('ascii'))

                for record in SeqIO.parse(fasta_io, "fasta"):
                    species_accession = record.id
                    key = get_key(species_accession, marker_id)
                    keys_to_keep.append(key)

        print(f"Keeping {len(keys_to_keep)} representatives")

        print(f"Loading treedict {treedict}")

        treedict = TreeDict.load(treedict)
        print("Total", len(treedict))
        missing = []
        for key in track(keys_to_keep):
            if key not in treedict:
                missing.append(key)

        print(f"{len(missing)} representatives missing output {len(keys_to_keep)} (total: {len(treedict)})")
        if len(missing):
            # PERF: use a set for O(1) membership instead of the original O(n) list scan
            missing_set = set(missing)
            keys_to_keep = [k for k in keys_to_keep if k not in missing_set]

        new_treedict = TreeDict(treedict.classification_tree)
        new_treedict.update({k:treedict[k] for k in keys_to_keep})
        print("Total after pruning", len(new_treedict))

        print("Saving treedict to", output)
        new_treedict.save(output)