From 4acce8f9c366a3569e2425bf1d1e3610ba233bc1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1t=C3=A9=20Balajti?= <mate.balajti@unibas.ch>
Date: Wed, 9 Aug 2023 13:40:50 +0200
Subject: [PATCH] feat: update setup.py to allow tool install

---
 README.md                      |  3 +--
 setup.py                       |  7 ++++++-
 term_frag_sel/cli.py           | 14 ++++++++------
 term_frag_sel/fragmentation.py |  2 +-
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 6b914a4..1099b91 100644
--- a/README.md
+++ b/README.md
@@ -29,8 +29,7 @@ Output:
 To install package, run
 
 ```
-pip install -r requirements.txt
-pip install -r requirements_dev.txt
+pip install .
 ```
 
 
diff --git a/setup.py b/setup.py
index abcfdb7..21a2fa2 100644
--- a/setup.py
+++ b/setup.py
@@ -18,5 +18,10 @@ setup(
     author_email='hmadge@ethz.ch',
     description='Terminal fragment selector',
     packages=find_packages(),
-    install_requires=INSTALL_REQUIRES
+    install_requires=INSTALL_REQUIRES,
+    entry_points={
+        'console_scripts': [
+            'terminal-fragment-selector=term_frag_sel.cli:main'
+            ]
+        }
 )
diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py
index 67b9f7d..37188c9 100644
--- a/term_frag_sel/cli.py
+++ b/term_frag_sel/cli.py
@@ -18,13 +18,15 @@ logging.basicConfig(
 logger = logging.getLogger("main")
 
 
-def main(args: argparse.Namespace):
+def main():
     """Use CLI arguments to fragment sequences and output text file \
     with selected terminal fragments.
 
     Args:
         args (parser): list of arguments from CLI.
     """
+    args = parse_arguments()
+
     if not isinstance(args, argparse.Namespace):
         raise TypeError("Input should be argparse.Namespace")
 
@@ -38,8 +40,10 @@ def main(args: argparse.Namespace):
     logger.info("Fragmentation of %s...", args.fasta)
     splits = np.arange(0, len(list(fasta))+args.size, args.size)
 
-    for i, split in enumerate(splits):
-        fasta_dict = fasta[split:splits[i+1]]
+    for i in range(len(splits) - 1):
+        split = splits[i]
+        keys = list(fasta.keys())[split:splits[i+1]]
+        fasta_dict = {key: fasta[key] for key in keys}
         term_frags = fragmentation(fasta_dict, seq_counts,
                                    args.mean, args.std)
 
@@ -132,6 +136,4 @@ if __name__ == '__main__':
         level=logging.INFO,
     )
     logger = logging.getLogger(__name__)
-
-    arguments = parse_arguments()
-    main(arguments)
+    main()
diff --git a/term_frag_sel/fragmentation.py b/term_frag_sel/fragmentation.py
index 9cd8a0f..f3471a0 100644
--- a/term_frag_sel/fragmentation.py
+++ b/term_frag_sel/fragmentation.py
@@ -24,7 +24,7 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
 
     term_frags = []
     for seq_id, seq in fasta.items():
-        counts = seq_counts[seq_counts["seqID"] == seq_id]["count"]
+        counts = seq_counts[seq_counts["seqID"] == str(seq_id)]["count"]
         for _ in range(counts.iloc[0]):
             cuts = []
             seq_len = len(seq)
-- 
GitLab