Coverage for barbet/apps.py: 45.42%
240 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-12 04:23 +0000
1from typing import TYPE_CHECKING
2from pathlib import Path
3from enum import Enum
4from collections import defaultdict
5from rich.console import Console
6from rich.progress import track
7from torchapp import TorchApp, Param, method, main, tool
9from .output import print_polars_df
11if TYPE_CHECKING:
12 from collections.abc import Iterable
13 from torchmetrics import Metric
14 from hierarchicalsoftmax import SoftmaxNode
15 from torch import nn
16 import lightning as L
17 # import pandas as pd
18 import polars as pl
# Shared Rich console for all user-facing status output in this module.
console = Console()
class ImageFormat(str, Enum):
    """Supported output image formats; ``NONE`` (empty string) disables image output."""

    NONE = ""
    PNG = "png"
    JPG = "jpg"
    SVG = "svg"
    PDF = "pdf"
    DOT = "dot"

    def __str__(self):
        # Render as the bare extension (e.g. "png"), not "ImageFormat.PNG".
        return self.value

    def __bool__(self) -> bool:
        """True when a concrete format was chosen, i.e. the value is non-empty."""
        return bool(self.value)
class Barbet(TorchApp):
    """TorchApp that assigns taxonomic labels to genomes with a hierarchical-softmax classifier."""
    @method
    def setup(
        self,
        memmap: str = None,
        memmap_index: str = None,
        treedict: str = None,
        stack_size: int = 32,
        in_memory: bool = False,
        tip_alpha: float = None,
    ) -> None:
        """
        Load the classification tree and the memory-mapped embedding array for training.

        Args:
            memmap: Path to the memmap file holding the protein embeddings.
            memmap_index: Path to a text file with one "<genome>/<gene>" key per memmap row.
            treedict: Path to a saved TreeDict with per-key classification nodes.
            stack_size: Number of embeddings stacked together per training item.
            in_memory: If True, copy the whole memmap into RAM.
            tip_alpha: Optional loss weight set on the parents of the tree tips.

        Raises:
            ValueError: If treedict, memmap or memmap_index is not given.
        """
        if not treedict:
            raise ValueError("treedict is required")
        if not memmap:
            raise ValueError("memmap is required")
        if not memmap_index:
            raise ValueError("memmap_index is required")

        from hierarchicalsoftmax import TreeDict
        import numpy as np
        from barbet.data import read_memmap

        self.stack_size = stack_size

        print(f"Loading treedict {treedict}")
        individual_treedict = TreeDict.load(treedict)
        # Keep only the classification tree; entries are re-added below keyed by
        # genome accession rather than by individual "<genome>/<gene>" key.
        self.treedict = TreeDict(
            classification_tree=individual_treedict.classification_tree
        )

        # Sets the loss weighting for the tips
        if tip_alpha:
            for tip in self.treedict.classification_tree.leaves:
                tip.parent.alpha = tip_alpha

        print("Loading memmap")
        self.accession_to_array_index = defaultdict(list)
        with open(memmap_index) as f:
            for key_index, key in enumerate(f):
                key = key.strip()
                # The genome accession is the part before the first "/".
                accession = key.strip().split("/")[0]

                # First time we see this genome: register its classification node.
                if len(self.accession_to_array_index[accession]) == 0:
                    self.treedict[accession] = individual_treedict[key]

                self.accession_to_array_index[accession].append(key_index)
        # key_index is the last row index, so the row count is one more.
        count = key_index + 1
        self.array = read_memmap(memmap, count)

        # If there's enough memory, then read into RAM
        if in_memory:
            self.array = np.array(self.array)

        self.classification_tree = self.treedict.classification_tree
        assert self.classification_tree is not None

        # Get list of gene families
        # NOTE(review): family_ids is built but never stored or returned —
        # confirm whether this is vestigial code.
        family_ids = set()
        for accession in self.treedict:
            gene_id = accession.split("/")[-1]
            family_ids.add(gene_id)
104 @method
105 def model(
106 self,
107 features: int = 768,
108 intermediate_layers: int = 2,
109 growth_factor: float = 2.0,
110 attention_size: int = 512,
111 ) -> "nn.Module":
112 from barbet.models import BarbetModel
114 return BarbetModel(
115 classification_tree=self.classification_tree,
116 features=features,
117 intermediate_layers=intermediate_layers,
118 growth_factor=growth_factor,
119 attention_size=attention_size,
120 )
122 @method
123 def loss_function(self):
124 from hierarchicalsoftmax import HierarchicalSoftmaxLoss
126 return HierarchicalSoftmaxLoss(root=self.classification_tree)
128 @method
129 def metrics(self) -> "list[tuple[str,Metric]]":
130 from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric
131 from barbet.data import RANKS
133 rank_accuracy = RankAccuracyTorchMetric(
134 root=self.classification_tree,
135 ranks={1 + i: rank for i, rank in enumerate(RANKS)},
136 )
138 return [("rank_accuracy", rank_accuracy)]
140 @method
141 def data(
142 self,
143 max_items: int = 0,
144 num_workers: int = 4,
145 validation_partition: int = 0,
146 batch_size: int = 4,
147 test_partition: int = -1,
148 train_all: bool = False,
149 ) -> "Iterable|L.LightningDataModule":
150 from barbet.data import BarbetDataModule
152 return BarbetDataModule(
153 array=self.array,
154 accession_to_array_index=self.accession_to_array_index,
155 treedict=self.treedict,
156 max_items=max_items,
157 batch_size=batch_size,
158 num_workers=num_workers,
159 validation_partition=validation_partition,
160 test_partition=test_partition,
161 stack_size=self.stack_size,
162 train_all=train_all,
163 )
165 @method
166 def module_class(self) :
167 from .modules import BarbetLightningModule
168 return BarbetLightningModule
170 @method
171 def extra_hyperparameters(self, embedding_model: str = "") -> dict:
172 """Extra hyperparameters to save with the module."""
173 assert embedding_model, "Please provide an embedding model."
174 from barbet.embeddings.esm import ESMEmbedding
176 embedding_model = embedding_model.lower()
177 if embedding_model.startswith("esm"):
178 layers = embedding_model[3:].strip()
179 embedding_model = ESMEmbedding()
180 embedding_model.setup(layers=layers)
181 else:
182 raise ValueError(f"Cannot understand embedding model: {embedding_model}")
184 return dict(
185 embedding_model=embedding_model,
186 classification_tree=self.treedict.classification_tree,
187 stack_size=self.stack_size,
188 )
    @method
    def prediction_dataloader(
        self,
        module,
        genome_path: Path,
        markers: dict[str, str],
        batch_size: int = Param(
            64, help="The batch size for the prediction dataloader."
        ),
        cpus: int = Param(
            1, help="The number of CPUs to use for the prediction dataloader."
        ),
        dataloader_workers: int = Param(
            4, help="The number of workers to use for the dataloader."
        ),
        repeats: int = Param(
            2,
            help="The minimum number of times to use each protein embedding in the prediction.",
        ),
        **kwargs,
    ) -> "Iterable":
        """
        Build a DataLoader of embedding stacks for a single genome.

        Embeds each marker-gene FASTA with the checkpoint's embedding model,
        then wraps the resulting array in a BarbetPredictionDataset (also
        stored on ``self.prediction_dataset``).

        Args:
            module: The loaded Lightning module; its hparams hold the embedding
                model, the classification tree and the stack size.
            genome_path: Path to the genome FASTA; its stem names the genome.
            markers: Maps marker-set name ("bac120"/"ar53") to the marker-gene
                FASTA files for this genome.

        Returns:
            A non-shuffling DataLoader over the prediction dataset.
        """
        import torch
        import numpy as np
        from torch.utils.data import DataLoader
        from barbet.data import BarbetPredictionDataset

        # Set PyTorch thread limits
        torch.set_num_threads(cpus)

        # Get hyperparameters from checkpoint
        stack_size = module.hparams.get("stack_size", 32)
        self.classification_tree = module.hparams.classification_tree

        # extract domain from the model: the root node name tells us whether the
        # checkpoint was trained on archaea (ar53) or bacteria (bac120)
        domain = "ar53" if self.classification_tree.name == "d__Archaea" else "bac120"

        #######################
        # Create Embeddings
        #######################
        embeddings = []
        accessions = []

        fastas = markers[domain]
        for fasta in track(
            fastas, description="[cyan]Embedding... ", total=len(fastas)
        ):
            # read the fasta file sequence remove the header
            # NOTE(review): assumes a single-record FASTA with the full sequence
            # on the line directly after the header — confirm upstream guarantees this.
            fasta = Path(fasta)
            seq = fasta.read_text().split("\n")[1]
            vector = module.hparams.embedding_model(seq)
            # Only keep finite, non-null embeddings; the accession is recorded in
            # the same branch so `embeddings` and `accessions` stay index-aligned.
            if vector is not None and not torch.isnan(vector).any():
                vector = vector.cpu().detach().clone().numpy()
                embeddings.append(vector)

                gene_family_id = fasta.stem
                accession = f"{genome_path.stem}/{gene_family_id}"
                accessions.append(accession)

                del vector

        # Compact the embeddings to float16 to halve memory before stacking.
        embeddings = np.asarray(embeddings).astype(np.float16)

        self.prediction_dataset = BarbetPredictionDataset(
            array=embeddings,
            accessions=accessions,
            stack_size=stack_size,
            repeats=repeats,
            seed=42,  # fixed seed so repeated runs build identical stacks
        )
        dataloader = DataLoader(
            self.prediction_dataset,
            batch_size=batch_size,
            num_workers=dataloader_workers,
            shuffle=False,  # keep dataset order so results align with genomes
        )

        return dataloader
268 def node_to_str(self, node: "SoftmaxNode") -> str:
269 """
270 Converts the node to a string
271 """
272 return str(node).split(",")[-1].strip()
    @main(
        "load_checkpoint",
        "prediction_trainer",
        "prediction_dataloader",
    )
    def predict(
        self,
        input: list[Path] = Param(
            default=...,
            help="FASTA files or directories of FASTA files. Requires genome in an individual FASTA file."
        ),
        output_dir: Path = Param("output", help="A path to the output directory."),
        output_csv: Path = Param(
            default=None, help="A path to output the results as a CSV."
        ),
        cpus: int = Param(
            1, help="The number of CPUs to use."
        ),
        pfam_db: str = Param(
            "https://data.ace.uq.edu.au/public/gtdbtk/release95/markers/pfam/Pfam-A.hmm",
            help="The Pfam database to use.",
        ),
        tigr_db: str = Param(
            "https://data.ace.uq.edu.au/public/gtdbtk/release95/markers/tigrfam/tigrfam.hmm",
            help="The TIGRFAM database to use.",
        ),
        **kwargs,
    ):
        """Barbet is a tool for assigning taxonomic labels to genomes using Machine Learning.

        Collects genome FASTA files, extracts single-copy marker genes, runs the
        model per genome, streams results to ``output_csv`` and returns the
        combined polars DataFrame.
        """
        # import pandas as pd
        import polars as pl
        from itertools import chain
        from barbet.markers import extract_markers_genes

        # Get list of files
        files = []
        if isinstance(input, (str, Path)):
            input = [input]
        assert len(input) > 0, "No input files provided."
        for path in input:
            if path.is_dir():
                # Accept the common FASTA extensions, plain or gzipped.
                for file in chain(
                    path.rglob("*.fa"),
                    path.rglob("*.fasta"),
                    path.rglob("*.fna"),
                    path.rglob("*.fa.gz"),
                    path.rglob("*.fasta.gz"),
                    path.rglob("*.fna.gz"),
                ):
                    files.append(file)
            elif path.is_file():
                files.append(path)

        # Check if any files were found
        if len(files) == 0:
            raise ValueError(
                f"No files found in {input}. Please provide a directory or a list of files."
            )

        # Check if output directory exists
        self.output_dir = Path(output_dir)
        output_csv = output_csv or self.output_dir / "barbet-predictions.csv"
        output_csv = Path(output_csv)
        output_csv.parent.mkdir(exist_ok=True, parents=True)
        console.print(
            f"Writing results for {len(files)} genome{'s' if len(files) > 1 else ''} to '{output_csv}'"
        )

        ####################
        # Extract single copy marker genes
        ####################
        markers_gene_map = extract_markers_genes(
            genomes={file.stem: str(file) for file in files},
            out_dir=str(self.output_dir),
            cpus=cpus,
            force=True,
            pfam_db=self.process_location(pfam_db),
            tigr_db=self.process_location(tigr_db),
        )

        # Load the model
        module = self.load_checkpoint(**kwargs)
        trainer = self.prediction_trainer(module, **kwargs)

        # Make predictions for each file
        total_df = None
        for genome_path, maker_genes in markers_gene_map.items():
            genome_path = Path(genome_path)
            prediction_dataloader = self.prediction_dataloader(module, genome_path, maker_genes, cpus=cpus, **kwargs)
            module.setup_prediction(self, genome_path.name)
            trainer.predict(module, dataloaders=prediction_dataloader)
            results_df = module.results_df

            if total_df is None:
                # First genome: write the CSV with its header.
                total_df = results_df
                if output_csv:
                    results_df.write_csv(output_csv)
            else:
                total_df = pl.concat([total_df, results_df], how="vertical")
                if output_csv:
                    # Subsequent genomes: append rows without repeating the header.
                    with open(output_csv, mode="a") as f:
                        results_df.write_csv(f, include_header=False)

        print_polars_df(
            total_df[["name", "species_prediction", "species_probability", ]],
            column_names=["Genome", "Species", "Probability"],
        )
        console.print(f"Saved to: '{output_csv}'")
        return total_df
    @tool(
        "load_checkpoint",
        "prediction_trainer",
        "prediction_dataloader_memmap",
    )
    def predict_memmap(
        self,
        output_csv: Path = Param(
            default=None, help="A path to output the results as a CSV."
        ),
        treedict:Path = Param(None, help="A path to a TreeDict with the ground truth lineage."),
        probabilities: bool = Param(
            default=False, help="If True, include probabilities for all the nodes in the taxonomic tree."
        ),
        **kwargs,
    ):
        """Run predictions over a precomputed memmap of embeddings and write a CSV.

        If a ground-truth TreeDict is supplied, a "<rank>_true" column is added
        for every taxonomic rank so predictions can be scored afterwards.
        """
        module = self.load_checkpoint(**kwargs)
        trainer = self.prediction_trainer(module, **kwargs)
        prediction_dataloader = self.prediction_dataloader_memmap(module, **kwargs)

        # prediction_dataloader_memmap stored the dataset on self; its stacks
        # carry the genome names in dataset order.
        module.setup_prediction(self, [stack.genome for stack in self.prediction_dataset.stacks], save_probabilities=probabilities)
        trainer.predict(module, dataloaders=prediction_dataloader, return_predictions=False)
        results_df = module.results_df

        genome_name_set = set(results_df['name'].unique())

        if treedict is not None:
            from hierarchicalsoftmax import TreeDict
            from barbet.data import RANKS
            import polars as pl

            # rank -> {genome_name -> true taxon name}
            true_values = defaultdict(dict)

            console.print(f"Adding true values from TreeDict '{treedict}'")
            treedict = TreeDict.load(treedict)

            # Get lineage to map
            for accession in track(treedict.keys()):
                genome_name = accession.split("/")[0]
                if genome_name in genome_name_set:
                    node = treedict.node(accession)
                    # Skip the root (ancestors[0]) and include the node itself.
                    lineage = node.ancestors[1:] + (node,)
                    for rank, lineage_node in zip(RANKS, lineage):
                        true_values[rank][genome_name] = lineage_node.name.strip()

            # Join the ground truth onto the predictions, one column per rank.
            for rank in RANKS:
                results_df = results_df.with_columns(
                    pl.col("name").map_elements(true_values[rank].get, return_dtype=pl.Utf8).alias(f"{rank}_true")
                )

        console.print(f"Writing to '{output_csv}'")
        output_csv = Path(output_csv)
        output_csv.parent.mkdir(exist_ok=True, parents=True)
        results_df.write_csv(output_csv)

        return results_df
    @method
    def prediction_dataloader_memmap(
        self,
        module,
        memmap:Path = Param(None, help="A path to the memmap file containing the protein embeddings."),
        memmap_index:Path = Param(None, help="A path to the memmap index file containing the accessions."),
        batch_size: int = Param(
            64, help="The batch size for the prediction dataloader."
        ),
        num_workers: int = 4,
        repeats: int = Param(
            2,
            help="The minimum number of times to use each protein embedding in the prediction.",
        ),
        genomes:Path=Param(None, help="A path to a text file with the accessions for the genome to use."),
        **kwargs,
    ) -> "Iterable":
        """Build a prediction DataLoader directly from an existing embedding memmap.

        Also stores the dataset on ``self.prediction_dataset`` for callers
        (e.g. ``predict_memmap``) that need the stack/genome mapping.

        Returns:
            A non-shuffling DataLoader over the prediction dataset.
        """
        from barbet.data import read_memmap
        from torch.utils.data import DataLoader
        from barbet.data import BarbetPredictionDataset

        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError for these user-input checks instead.
        assert memmap is not None, "Please provide a path to the memmap file."
        assert memmap.exists(), f"Memmap file does not exist: {memmap}"
        assert memmap_index is not None, "Please provide a path to the memmap index file."
        assert memmap_index.exists(), f"Memmap index file does not exist: {memmap_index}"

        # Read the memmap array index
        console.print(f"Reading memmap array index '{memmap_index}'")
        accessions = memmap_index.read_text().strip().split("\n")
        count = len(accessions)
        console.print(f"Found {count} accessions")

        # Load the memmap array itself
        console.print(f"Loading memmap array '{memmap}'")
        array = read_memmap(memmap, count)

        # Get hyperparameters from checkpoint
        self.classification_tree = module.hparams.classification_tree
        stack_size = module.hparams.get("stack_size", 32)

        # If a genomes file is provided, restrict prediction to the genome
        # accessions listed in it (one per line).
        genome_filter = None
        if genomes:
            assert genomes.exists(), f"Genomes file does not exist: {genomes}"
            genome_filter = set(Path(genomes).read_text().strip().split("\n"))

        self.prediction_dataset = BarbetPredictionDataset(
            array=array,
            accessions=accessions,
            stack_size=stack_size,
            repeats=repeats,
            genome_filter=genome_filter,
            seed=42,  # fixed seed so repeated runs build identical stacks
        )
        dataloader = DataLoader(
            self.prediction_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,  # keep dataset order so results align with genomes
        )

        return dataloader
507 @method
508 def monitor(
509 self,
510 train_all: bool = False,
511 **kwargs,
512 ) -> str:
513 if train_all:
514 return "valid_loss"
515 return "genus"
517 def checkpoint(
518 self,
519 checkpoint:Path=Param(None, help="The path to a checkpoint file for the Barbet parameters. If not provided, then it will use a standard checkpoint."),
520 large:bool=Param(False, help="Whether or not to use the large standard checkpoint of the Barbet parameters."),
521 archaea:bool=Param(False, help="Whether or not to use the standard model for archaea. If not, then it uses the default model for bacteria."),
522 ) -> str:
523 if checkpoint:
524 return checkpoint
526 # Weights are here: https://figshare.unimelb.edu.au/articles/dataset/Trained_weights_for_Barbet/
527 # DOI: https://doi.org/10.26188/29578964
529 if archaea:
530 if large:
531 # barbet-ar53-ESM12-large.ckpt
532 return "https://figshare.unimelb.edu.au/ndownloader/files/56332160"
534 # barbet-ar53-ESM12-base.ckpt
535 return "https://figshare.unimelb.edu.au/ndownloader/files/56332157"
537 if large:
538 # barbet-bac120-ESM6-large.ckpt
539 return "https://figshare.unimelb.edu.au/ndownloader/files/56307647"
541 # barbet-bac120-ESM6-base.ckpt
542 return "https://figshare.unimelb.edu.au/ndownloader/files/56307671"