Coverage for barbet/apps.py: 45.42%

240 statements  


from typing import TYPE_CHECKING
from pathlib import Path
from enum import Enum
from collections import defaultdict
from rich.console import Console
from rich.progress import track
from torchapp import TorchApp, Param, method, main, tool

from .output import print_polars_df

if TYPE_CHECKING:
    from collections.abc import Iterable
    from torchmetrics import Metric
    from hierarchicalsoftmax import SoftmaxNode
    from torch import nn
    import lightning as L
    import polars as pl


console = Console()


class ImageFormat(str, Enum):
    """The image format to use for the output images."""

    NONE = ""
    PNG = "png"
    JPG = "jpg"
    SVG = "svg"
    PDF = "pdf"
    DOT = "dot"

    def __str__(self):
        return self.value

    def __bool__(self) -> bool:
        """Returns True if the image format is not empty."""
        return self.value != ""

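# A quick illustration of the enum's behaviour: str(ImageFormat.SVG) == "svg"
# for building output file names, while bool(ImageFormat.NONE) is False, so a
# falsy format can be used to skip image rendering altogether.
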

class Barbet(TorchApp):
    @method
    def setup(
        self,
        memmap: str = None,
        memmap_index: str = None,
        treedict: str = None,
        stack_size: int = 32,
        in_memory: bool = False,
        tip_alpha: float = None,
    ) -> None:
        """Loads the TreeDict and the memmap of protein embeddings used for training."""
        if not treedict:
            raise ValueError("treedict is required")
        if not memmap:
            raise ValueError("memmap is required")
        if not memmap_index:
            raise ValueError("memmap_index is required")

        from hierarchicalsoftmax import TreeDict
        import numpy as np
        from barbet.data import read_memmap

        self.stack_size = stack_size

        print(f"Loading treedict {treedict}")
        individual_treedict = TreeDict.load(treedict)
        self.treedict = TreeDict(
            classification_tree=individual_treedict.classification_tree
        )

        # Set the loss weighting for the tips
        if tip_alpha:
            for tip in self.treedict.classification_tree.leaves:
                tip.parent.alpha = tip_alpha

        print("Loading memmap")
        self.accession_to_array_index = defaultdict(list)
        count = 0
        with open(memmap_index) as f:
            for key_index, key in enumerate(f):
                key = key.strip()
                accession = key.split("/")[0]

                if len(self.accession_to_array_index[accession]) == 0:
                    self.treedict[accession] = individual_treedict[key]

                self.accession_to_array_index[accession].append(key_index)
                count = key_index + 1
        self.array = read_memmap(memmap, count)

        # If there's enough memory, read the whole array into RAM
        if in_memory:
            self.array = np.array(self.array)

        self.classification_tree = self.treedict.classification_tree
        assert self.classification_tree is not None

        # Get the set of gene families
        family_ids = set()
        for accession in self.treedict:
            gene_id = accession.split("/")[-1]
            family_ids.add(gene_id)
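    # Illustration of the index layout parsed above (accessions are
    # hypothetical): a memmap index file containing
    #     GCA_000005825.2/PF00380.20
    #     GCA_000005825.2/PF00410.20
    # yields accession_to_array_index["GCA_000005825.2"] == [0, 1], so every
    # embedding row belonging to a genome can be stacked together.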

    @method
    def model(
        self,
        features: int = 768,
        intermediate_layers: int = 2,
        growth_factor: float = 2.0,
        attention_size: int = 512,
    ) -> "nn.Module":
        from barbet.models import BarbetModel

        return BarbetModel(
            classification_tree=self.classification_tree,
            features=features,
            intermediate_layers=intermediate_layers,
            growth_factor=growth_factor,
            attention_size=attention_size,
        )

    @method
    def loss_function(self):
        from hierarchicalsoftmax import HierarchicalSoftmaxLoss

        return HierarchicalSoftmaxLoss(root=self.classification_tree)

    @method
    def metrics(self) -> "list[tuple[str,Metric]]":
        from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric
        from barbet.data import RANKS

        rank_accuracy = RankAccuracyTorchMetric(
            root=self.classification_tree,
            ranks={1 + i: rank for i, rank in enumerate(RANKS)},
        )

        return [("rank_accuracy", rank_accuracy)]
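    # The {1 + i: rank} mapping ties tree depth 1..len(RANKS) to a rank name,
    # assuming depth 0 is the root (domain) node, so accuracy is reported
    # separately for each taxonomic rank.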

    @method
    def data(
        self,
        max_items: int = 0,
        num_workers: int = 4,
        validation_partition: int = 0,
        batch_size: int = 4,
        test_partition: int = -1,
        train_all: bool = False,
    ) -> "Iterable|L.LightningDataModule":
        from barbet.data import BarbetDataModule

        return BarbetDataModule(
            array=self.array,
            accession_to_array_index=self.accession_to_array_index,
            treedict=self.treedict,
            max_items=max_items,
            batch_size=batch_size,
            num_workers=num_workers,
            validation_partition=validation_partition,
            test_partition=test_partition,
            stack_size=self.stack_size,
            train_all=train_all,
        )
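    # Assumption: validation_partition and test_partition look like
    # cross-validation fold indices handled inside BarbetDataModule, with the
    # -1 default meaning no fold is held out for testing.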

    @method
    def module_class(self):
        from .modules import BarbetLightningModule

        return BarbetLightningModule

    @method
    def extra_hyperparameters(self, embedding_model: str = "") -> dict:
        """Extra hyperparameters to save with the module."""
        assert embedding_model, "Please provide an embedding model."
        from barbet.embeddings.esm import ESMEmbedding

        embedding_model_name = embedding_model.lower()
        if embedding_model_name.startswith("esm"):
            layers = embedding_model_name[3:].strip()
            embedding_model = ESMEmbedding()
            embedding_model.setup(layers=layers)
        else:
            raise ValueError(f"Cannot understand embedding model: {embedding_model}")

        return dict(
            embedding_model=embedding_model,
            classification_tree=self.treedict.classification_tree,
            stack_size=self.stack_size,
        )
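    # Example of the parsing above: embedding_model="ESM12" lowercases to
    # "esm12", strips the "esm" prefix, and calls
    # ESMEmbedding().setup(layers="12").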

    @method
    def prediction_dataloader(
        self,
        module,
        genome_path: Path,
        markers: dict[str, list[str]],
        batch_size: int = Param(
            64, help="The batch size for the prediction dataloader."
        ),
        cpus: int = Param(
            1, help="The number of CPUs to use for the prediction dataloader."
        ),
        dataloader_workers: int = Param(
            4, help="The number of workers to use for the dataloader."
        ),
        repeats: int = Param(
            2,
            help="The minimum number of times to use each protein embedding in the prediction.",
        ),
        **kwargs,
    ) -> "Iterable":
        import torch
        import numpy as np
        from torch.utils.data import DataLoader
        from barbet.data import BarbetPredictionDataset

        # Set PyTorch thread limits
        torch.set_num_threads(cpus)

        # Get hyperparameters from the checkpoint
        stack_size = module.hparams.get("stack_size", 32)
        self.classification_tree = module.hparams.classification_tree

        # Infer the marker-gene domain from the model's classification tree
        domain = "ar53" if self.classification_tree.name == "d__Archaea" else "bac120"

        #######################
        # Create Embeddings
        #######################
        embeddings = []
        accessions = []

        fastas = markers[domain]
        for fasta in track(
            fastas, description="[cyan]Embedding... ", total=len(fastas)
        ):
            # Read the sequence from the FASTA file, joining the lines after the header
            fasta = Path(fasta)
            seq = "".join(fasta.read_text().split("\n")[1:])
            vector = module.hparams.embedding_model(seq)
            if vector is not None and not torch.isnan(vector).any():
                vector = vector.cpu().detach().clone().numpy()
                embeddings.append(vector)

                gene_family_id = fasta.stem
                accession = f"{genome_path.stem}/{gene_family_id}"
                accessions.append(accession)

            del vector

        embeddings = np.asarray(embeddings).astype(np.float16)

        self.prediction_dataset = BarbetPredictionDataset(
            array=embeddings,
            accessions=accessions,
            stack_size=stack_size,
            repeats=repeats,
            seed=42,
        )
        dataloader = DataLoader(
            self.prediction_dataset,
            batch_size=batch_size,
            num_workers=dataloader_workers,
            shuffle=False,
        )

        return dataloader

    def node_to_str(self, node: "SoftmaxNode") -> str:
        """
        Returns the last component of the node's comma-separated string representation.
        """
        return str(node).split(",")[-1].strip()
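    # For example, if str(node) renders a comma-separated lineage such as
    # "d__Bacteria, ..., g__Escherichia" (a hypothetical rendering), this
    # returns just "g__Escherichia".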

    @main(
        "load_checkpoint",
        "prediction_trainer",
        "prediction_dataloader",
    )
    def predict(
        self,
        input: list[Path] = Param(
            default=...,
            help="FASTA files or directories of FASTA files. Requires each genome to be in an individual FASTA file.",
        ),
        output_dir: Path = Param("output", help="A path to the output directory."),
        output_csv: Path = Param(
            default=None, help="A path to output the results as a CSV."
        ),
        cpus: int = Param(
            1, help="The number of CPUs to use."
        ),
        pfam_db: str = Param(
            "https://data.ace.uq.edu.au/public/gtdbtk/release95/markers/pfam/Pfam-A.hmm",
            help="The Pfam database to use.",
        ),
        tigr_db: str = Param(
            "https://data.ace.uq.edu.au/public/gtdbtk/release95/markers/tigrfam/tigrfam.hmm",
            help="The TIGRFAM database to use.",
        ),
        **kwargs,
    ):
        """Barbet is a tool for assigning taxonomic labels to genomes using machine learning."""
        import polars as pl
        from itertools import chain
        from barbet.markers import extract_markers_genes

        # Get the list of input files
        files = []
        if isinstance(input, (str, Path)):
            input = [input]
        assert len(input) > 0, "No input files provided."
        for path in input:
            if path.is_dir():
                for file in chain(
                    path.rglob("*.fa"),
                    path.rglob("*.fasta"),
                    path.rglob("*.fna"),
                    path.rglob("*.fa.gz"),
                    path.rglob("*.fasta.gz"),
                    path.rglob("*.fna.gz"),
                ):
                    files.append(file)
            elif path.is_file():
                files.append(path)

        # Check that files were found
        if len(files) == 0:
            raise ValueError(
                f"No files found in {input}. Please provide a directory or a list of files."
            )

        # Make sure the output directory exists
        self.output_dir = Path(output_dir)
        output_csv = output_csv or self.output_dir / "barbet-predictions.csv"
        output_csv = Path(output_csv)
        output_csv.parent.mkdir(exist_ok=True, parents=True)
        console.print(
            f"Writing results for {len(files)} genome{'s' if len(files) > 1 else ''} to '{output_csv}'"
        )

        ####################
        # Extract single-copy marker genes
        ####################
        markers_gene_map = extract_markers_genes(
            genomes={file.stem: str(file) for file in files},
            out_dir=str(self.output_dir),
            cpus=cpus,
            force=True,
            pfam_db=self.process_location(pfam_db),
            tigr_db=self.process_location(tigr_db),
        )

        # Load the model
        module = self.load_checkpoint(**kwargs)
        trainer = self.prediction_trainer(module, **kwargs)

        # Make predictions for each genome
        total_df = None
        for genome_path, marker_genes in markers_gene_map.items():
            genome_path = Path(genome_path)
            prediction_dataloader = self.prediction_dataloader(
                module, genome_path, marker_genes, cpus=cpus, **kwargs
            )
            module.setup_prediction(self, genome_path.name)
            trainer.predict(module, dataloaders=prediction_dataloader)
            results_df = module.results_df

            if total_df is None:
                total_df = results_df
                if output_csv:
                    results_df.write_csv(output_csv)
            else:
                total_df = pl.concat([total_df, results_df], how="vertical")

                if output_csv:
                    with open(output_csv, mode="a") as f:
                        results_df.write_csv(f, include_header=False)

        print_polars_df(
            total_df[["name", "species_prediction", "species_probability"]],
            column_names=["Genome", "Species", "Probability"],
        )
        console.print(f"Saved to: '{output_csv}'")
        return total_df
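    # Hypothetical command-line sketch: torchapp exposes the @main method as
    # the app's entry point, so an invocation might look like
    #   barbet --input genomes/ --output-dir results --cpus 8
    # (flag spelling depends on how torchapp renders the parameter names).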

    @tool(
        "load_checkpoint",
        "prediction_trainer",
        "prediction_dataloader_memmap",
    )
    def predict_memmap(
        self,
        output_csv: Path = Param(
            default=None, help="A path to output the results as a CSV."
        ),
        treedict: Path = Param(None, help="A path to a TreeDict with the ground truth lineage."),
        probabilities: bool = Param(
            default=False, help="If True, include probabilities for all the nodes in the taxonomic tree."
        ),
        **kwargs,
    ):
        """Assigns taxonomic labels to genomes from a memmap of precomputed protein embeddings."""
        module = self.load_checkpoint(**kwargs)
        trainer = self.prediction_trainer(module, **kwargs)
        prediction_dataloader = self.prediction_dataloader_memmap(module, **kwargs)

        module.setup_prediction(
            self,
            [stack.genome for stack in self.prediction_dataset.stacks],
            save_probabilities=probabilities,
        )
        trainer.predict(module, dataloaders=prediction_dataloader, return_predictions=False)
        results_df = module.results_df

        genome_name_set = set(results_df["name"].unique())

        if treedict is not None:
            from hierarchicalsoftmax import TreeDict
            from barbet.data import RANKS
            import polars as pl

            true_values = defaultdict(dict)

            console.print(f"Adding true values from TreeDict '{treedict}'")
            treedict = TreeDict.load(treedict)

            # Map each genome to its true lineage
            for accession in track(treedict.keys()):
                genome_name = accession.split("/")[0]
                if genome_name in genome_name_set:
                    node = treedict.node(accession)
                    lineage = node.ancestors[1:] + (node,)
                    for rank, lineage_node in zip(RANKS, lineage):
                        true_values[rank][genome_name] = lineage_node.name.strip()

            for rank in RANKS:
                results_df = results_df.with_columns(
                    pl.col("name").map_elements(true_values[rank].get, return_dtype=pl.Utf8).alias(f"{rank}_true")
                )

        if output_csv:
            console.print(f"Writing to '{output_csv}'")
            output_csv = Path(output_csv)
            output_csv.parent.mkdir(exist_ok=True, parents=True)
            results_df.write_csv(output_csv)

        return results_df
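    # With a TreeDict supplied, the results gain one "{rank}_true" column per
    # rank (e.g. "species_true"), so a prediction column such as
    # "species_prediction" can be scored against the ground truth offline.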

    @method
    def prediction_dataloader_memmap(
        self,
        module,
        memmap: Path = Param(None, help="A path to the memmap file containing the protein embeddings."),
        memmap_index: Path = Param(None, help="A path to the memmap index file containing the accessions."),
        batch_size: int = Param(
            64, help="The batch size for the prediction dataloader."
        ),
        num_workers: int = 4,
        repeats: int = Param(
            2,
            help="The minimum number of times to use each protein embedding in the prediction.",
        ),
        genomes: Path = Param(None, help="A path to a text file with the accessions of the genomes to use."),
        **kwargs,
    ) -> "Iterable":
        from torch.utils.data import DataLoader
        from barbet.data import read_memmap, BarbetPredictionDataset

        assert memmap is not None, "Please provide a path to the memmap file."
        assert memmap.exists(), f"Memmap file does not exist: {memmap}"
        assert memmap_index is not None, "Please provide a path to the memmap index file."
        assert memmap_index.exists(), f"Memmap index file does not exist: {memmap_index}"

        # Read the memmap array index
        console.print(f"Reading memmap array index '{memmap_index}'")
        accessions = memmap_index.read_text().strip().split("\n")
        count = len(accessions)
        console.print(f"Found {count} accessions")

        # Load the memmap array itself
        console.print(f"Loading memmap array '{memmap}'")
        array = read_memmap(memmap, count)

        # Get hyperparameters from the checkpoint
        self.classification_tree = module.hparams.classification_tree
        stack_size = module.hparams.get("stack_size", 32)

        # If a genomes file is provided, restrict predictions to those genomes
        genome_filter = None
        if genomes:
            assert genomes.exists(), f"Genomes file does not exist: {genomes}"
            genome_filter = set(genomes.read_text().strip().split("\n"))

        self.prediction_dataset = BarbetPredictionDataset(
            array=array,
            accessions=accessions,
            stack_size=stack_size,
            repeats=repeats,
            genome_filter=genome_filter,
            seed=42,
        )
        dataloader = DataLoader(
            self.prediction_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
        )

        return dataloader

    @method
    def monitor(
        self,
        train_all: bool = False,
        **kwargs,
    ) -> str:
        if train_all:
            return "valid_loss"
        return "genus"
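    # "genus" presumably refers to the genus-level rank accuracy reported by
    # the metric above; when train_all is set the validation loss is monitored
    # instead, presumably because no partition is truly held out.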

    def checkpoint(
        self,
        checkpoint: Path = Param(None, help="The path to a checkpoint file for the Barbet parameters. If not provided, then a standard checkpoint is used."),
        large: bool = Param(False, help="Whether or not to use the large standard checkpoint of the Barbet parameters."),
        archaea: bool = Param(False, help="Whether or not to use the standard model for archaea. If not, then the default model for bacteria is used."),
    ) -> str:
        if checkpoint:
            return checkpoint

        # Weights are here: https://figshare.unimelb.edu.au/articles/dataset/Trained_weights_for_Barbet/
        # DOI: https://doi.org/10.26188/29578964

        if archaea:
            if large:
                # barbet-ar53-ESM12-large.ckpt
                return "https://figshare.unimelb.edu.au/ndownloader/files/56332160"

            # barbet-ar53-ESM12-base.ckpt
            return "https://figshare.unimelb.edu.au/ndownloader/files/56332157"

        if large:
            # barbet-bac120-ESM6-large.ckpt
            return "https://figshare.unimelb.edu.au/ndownloader/files/56307647"

        # barbet-bac120-ESM6-base.ckpt
        return "https://figshare.unimelb.edu.au/ndownloader/files/56307671"
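
# Checkpoint selection summary (from the branches above):
#   archaea=True,  large=True  -> barbet-ar53-ESM12-large.ckpt
#   archaea=True,  large=False -> barbet-ar53-ESM12-base.ckpt
#   archaea=False, large=True  -> barbet-bac120-ESM6-large.ckpt
#   archaea=False, large=False -> barbet-bac120-ESM6-base.ckpt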