#! /usr/bin/python3
# -*- mode: python; coding: utf-8 -*-
#
# filtersum
#
# Filter checksum/hash file such as generated by md5sum for given set
# of paths.
#
# Copyright (C) 2022-2023 Elmar Hoffmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
#

import sys
import os
import io
import signal
import errno
import logging
import argparse


class FileIter:
    """Iterator yielding the lines of a text stream via ``readline()``.

    Each item is fetched with its own ``readline()`` call on the wrapped
    stream and is returned unmodified, including any trailing newline.
    Iteration ends as soon as ``readline()`` returns an empty string.
    """

    def __init__(self, file_):
        # Only text streams are supported; the end-of-input check in
        # __next__ relies on readline() returning str.
        assert isinstance(file_, io.TextIOBase)
        self.file = file_

    def __iter__(self):
        return self

    def __next__(self):
        text = self.file.readline()
        if text:
            return text
        # An empty string from readline() means end of input.
        raise StopIteration

def main():
    """Filter checksum/hash files down to the lines for a given set of paths.

    The set of wanted paths is read from the file given with -f/--file,
    one path per line; without that option the set is empty and nothing
    matches.  Every input line must contain a two-space separator
    between checksum and path.  By default the checksum comes first
    ("SUM  PATH"); with -r/--reverse the path comes first ("PATH  SUM").
    Lines whose path is in the set are printed unless -q/--quiet is given.

    Exits via sys.exit() with os.EX_NOINPUT if opening an input file
    raises IOError during argument parsing, with os.EX_DATAERR on a
    line lacking the separator, with os.EX_IOERR (through
    exit_brokenpipe()) when standard output is a broken pipe, and with
    1 if --quiet was given and no line matched.
    """
    logging.basicConfig(format=os.path.basename(sys.argv[0])
                        + ": %(levelname)s: %(message)s")
    logger = logging.getLogger()

    parser = argparse.ArgumentParser(
        description="Filter checksum/hash file for given set of paths.",
        add_help=False)
    general_args = parser.add_argument_group("general arguments")

    general_args.add_argument("-?", "-h", "--help", action="help",
                              help="print this help and exit")
    general_args.add_argument("--usage", nargs=0, action=UsageAction,
                              help="print short usage and exit")
    general_args.add_argument("-V", "--version", action="version",
                              version="%(prog)s 2.0",
                              help="print version information and exit")

    parser.add_argument("-q", "--quiet", action="store_true",
                        help="suppress all normal output")
    parser.add_argument("--debug", action="store_true",
                        help="show general debugging output")

    parser.add_argument("-f", "--file", type=argparse.FileType('r'),
                        help="read paths to filter for from file")
    parser.add_argument("-r", "--reverse", action="store_true",
                        help="handle checksum/hash file with path at beginning of line")
    parser.add_argument('files', metavar='FILE', nargs='*',
                        type=argparse.FileType('r'), default=[sys.stdin],
                        help='file to filter, \'-\' for standard input')

    try:
        args = parser.parse_args()
    except IOError as err:
        error('Error opening input file: {error}'.format(error=err))
        sys.exit(os.EX_NOINPUT)

    if args.debug:
        logger.setLevel(logging.DEBUG)

    filterset = set()
    if args.file:
        for line in FileIter(args.file):
            if line[-1] == '\n':
                line = line[:-1]
            filterset.add(line)
        logger.debug("Read %d paths from '%s'", len(filterset),
                     args.file.name)

    found = False
    for file_ in args.files:
        for lineno, line in enumerate(FileIter(file_), start=1):
            line = line.rstrip('\n\r')

            try:
                if args.reverse:
                    # "PATH  SUM": the path may itself contain double
                    # spaces, so the separator is the last occurrence.
                    path = line[:line.rindex('  ')]
                else:
                    # "SUM  PATH": the separator is the first double
                    # space; any later one belongs to the path.
                    path = line[line.index('  ') + 2:]
            except ValueError:
                error('Malformed input file {file}, line {lineno}:\n{line}'.format(
                    file=file_.name,
                    lineno=lineno,
                    line=line
                ))
                sys.exit(os.EX_DATAERR)

            if path not in filterset:
                continue

            found = True
            if not args.quiet:
                try:
                    print(line)
                except BrokenPipeError:
                    exit_brokenpipe()

    # In quiet mode the exit status is the only result reported.
    if args.quiet and not found:
        sys.exit(1)

class UsageAction(argparse.Action):
    """argparse action that prints the short usage line and exits.

    When the option is encountered, the parser's usage message is
    printed and the program terminates with status os.EX_USAGE.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parser.print_usage()
        parser.exit(status=os.EX_USAGE)

def exit_brokenpipe():
    """Terminate with os.EX_IOERR after standard output's pipe broke."""
    # Point the stdout file descriptor at /dev/null first, to avoid an
    # additional BrokenPipeError in TextIOWrapper during cleanup.
    null_fd = os.open(os.devnull, os.O_WRONLY)
    stdout_fd = sys.stdout.fileno()
    os.dup2(null_fd, stdout_fd)

    sys.exit(os.EX_IOERR)

def error(msg):
    """Write *msg* to standard error, prefixed with the program name.

    The prefix is the base name of sys.argv[0]; a newline is appended.
    """
    program = os.path.basename(sys.argv[0])
    sys.stderr.write("%s: %s\n" % (program, msg))

# Script entry point: run main() and translate interruption and
# broken-pipe conditions into exit statuses.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Exit with 128 + the signal number (presumably mirroring the
        # status shells report for a process terminated by a signal).
        sys.exit(128 + signal.SIGINT)
    finally:
        # Flush explicitly on every exit path so that a broken pipe
        # surfaces here, where it can be handled, rather than during
        # interpreter shutdown.
        try:
            sys.stdout.flush()
        except BrokenPipeError:
            exit_brokenpipe()

    # Reached only if main() returned normally and the flush succeeded.
    sys.exit(os.EX_OK)
