# SPDX-FileCopyrightText: Copyright (C) 2024-2025 Linaro Limited (or its affiliates). All rights reserved.
#
# -*- coding: iso-8859-1 -*-

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function
import itertools
import gdb
import json
import contextlib
import re

class ForgeThreadsCrossExpressions(gdb.MICommand):
    """
    -forge-threads-cross-expressions {"expressions":[{"expression":"<expression>", "language":"fortran"}],
                                      "fid":"<framework thread identifier expression>",
                                      "frames":{"<gid>":<frame>,},
                                      "gids":["1-4"],"inferiors":["1"]}

    Evaluate every entry of "expressions" in each selected thread and report
    one result per (thread, expression) pair.

    "gids" / "inferiors" restrict the visited threads (default: all).
    "frames" optionally maps a thread's global id (string key) to the frame
    level to evaluate in. "fid", when given, is evaluated once per thread and
    echoed alongside each of that thread's results. A per-expression
    "language" temporarily overrides GDB's 'language' setting.
    """

    # Focus accessors used below whenever GDB exposes a ``gdb.cuda`` module.
    _CUDA_FOCUS_METHODS = ("get_focus_physical", "set_focus_physical",
                           "get_focus_logical", "set_focus_logical")

    def __init__(self):
        super(ForgeThreadsCrossExpressions, self).__init__("-forge-threads-cross-expressions")

    def invoke(self, args):
        """
        Run the command.

        args: the MI argument words; joined with single spaces they must form
            one JSON object as shown in the class docstring.

        Returns {"result": [...]}. Each entry has "expression" and "gid",
        optionally "fid" and "language", and either "value" (the formatted
        value) or "error" (the gdb.error message). Entries are ordered by
        descending gid; within a gid, expressions keep their input order.

        Raises gdb.GdbError if the cuda module lacks a required focus method
        or the arguments are not valid JSON.
        """
        cuda = getattr(gdb, 'cuda', None)
        if cuda is not None:
            # Both physical and logical focus are saved and restored below.
            missing = [name for name in self._CUDA_FOCUS_METHODS if not hasattr(cuda, name)]
            if missing:
                raise gdb.GdbError(self.name + " requires methods " + ", ".join(missing)
                                   + " when cuda module detected.")

        result_key = "result"
        data = {result_key : []}
        if not args:
            return data

        try:
            parsed_arguments = json.loads(" ".join(args))
        except ValueError as error:
            raise gdb.GdbError(self.name + ": invalid JSON arguments: " + str(error))

        gids = parse_range_arg(parsed_arguments, "gids")
        inferiors = parse_range_arg(parsed_arguments, "inferiors")
        framework_expression = parsed_arguments.get("fid", None)
        expressions = parsed_arguments["expressions"]
        # Optional: threads without an entry are evaluated in their default frame.
        frames = parsed_arguments.get("frames", {})

        physicalCords = None
        logicalCords  = None
        if cuda is not None:
            physicalCords = cuda.get_focus_physical()
            logicalCords  = cuda.get_focus_logical()

        with restore_thread(gdb.selected_thread(), physicalCords, logicalCords):
            for thread in global_threads_filter_switch(gids, inferiors, frames):
                # Format now, while this thread/frame is selected: a gdb.Value
                # left for the MI layer would be read after the thread switch-back.
                fid = str(gdb.parse_and_eval(framework_expression)) if framework_expression else None

                for expression in expressions:
                    language = expression.get("language")
                    expression_data = {"expression" : expression["expression"], "gid" : thread.global_num}
                    expression_data = set_optional_field(expression_data, "fid", framework_expression, fid)
                    expression_data = set_optional_field(expression_data, "language", language, language)

                    try:
                        expression_data["value"] = self._evaluate(expression["expression"], language)
                    except gdb.error as error:
                        expression_data["error"] = str(error)
                    data[result_key].append(expression_data)
        data[result_key].sort(key=lambda x: x["gid"], reverse=True)
        return data

    @staticmethod
    def _evaluate(text, language):
        """
        Evaluate ``text`` in the selected thread/frame and return it as a string.

        The value is converted to text inside the optional language override
        and before the caller restores the original thread. It is therefore
        read and printed in the context it was evaluated in. Lazy-fetch
        failures raise gdb.error here, where the caller records them for this
        expression only.
        """
        if not language:
            return str(gdb.parse_and_eval(text))
        with gdb.with_parameter('language', language.lower()):
            return str(gdb.parse_and_eval(text))

def set_optional_field(data, field, exists, value):
    """
    Store ``value`` under ``data[field]`` only when ``exists`` is truthy.

    The dict is updated in place and also returned, so callers may reassign it.
    """
    if not exists:
        return data
    data[field] = value
    return data

def parse_range_arg(args, flag):
    """
    Look up ``flag`` in the parsed argument mapping ``args``.

    When present, its value is expanded with expand_range(). When absent,
    the wildcard list ["all"] is returned so every id matches.
    """
    if flag in args:
        return expand_range(args[flag])
    return ["all"]

class restore_thread:
    """
    Context manager for undoing a thread switch.

    On exit it reselects the thread passed at construction (if it still
    exists), restores any saved CUDA focus, and reselects the stack frame
    that was selected when the manager was created.
    """
    def __init__(self, thread, cudaPhyiscalCoords, cudaLogicalCoords):
        """
        thread: the gdb.InferiorThread to switch back to; may be None when no
            thread is selected (e.g. no live process), in which case no
            thread switch is attempted on exit.
        cudaPhyiscalCoords / cudaLogicalCoords: focus to restore through
            gdb.cuda.set_focus_physical / set_focus_logical; falsy values
            are skipped.
        """
        self.thread = thread
        self.cudaPhyiscalCoords = cudaPhyiscalCoords
        self.cudaLogicalCoords  = cudaLogicalCoords
        # Switching back to a thread does not by itself bring back the frame
        # the user had selected in it, so remember that frame here.
        self.frame = None
        if thread is not None and thread is gdb.selected_thread():
            try:
                self.frame = gdb.selected_frame()
            except gdb.error:
                # No stack to remember (e.g. the thread is running).
                self.frame = None

    def __enter__(self):
        return self.thread

    def __exit__(self, *args):
        # The thread may be absent or may have exited during evaluation;
        # switching to it would then fail and mask the original outcome.
        thread_restored = self.thread is not None and self.thread.is_valid()
        if thread_restored:
            self.thread.switch()
        if hasattr(gdb, 'cuda'):
            if self.cudaPhyiscalCoords:
                gdb.cuda.set_focus_physical(self.cudaPhyiscalCoords)
            if self.cudaLogicalCoords:
                gdb.cuda.set_focus_logical(self.cudaLogicalCoords)
        # Reselect the frame last, once thread and focus context are back.
        if thread_restored and self.frame is not None and self.frame.is_valid():
            self.frame.select()

def expand_range(ranges):
    """
    Expand a GDB-style id list into the set of ids it covers.

    Each item of ``ranges`` may be:
      * an integer, or a string of decimal digits ("3");
      * an inclusive range string "<first>-<last>" ("1-4" -> 1, 2, 3, 4);
      * the literal "all", kept verbatim as a wildcard marker.
    Surrounding whitespace in string items is ignored. A lone string or
    integer is treated as a one-item list.

    Returns a set of ints (plus "all" if present), used for membership tests
    against thread and inferior numbers.

    Raises ValueError for items matching none of the forms above.
    """
    if isinstance(ranges, (str, int)):
        ranges = [ranges]
    expanded = set()
    for item in ranges:
        if isinstance(item, int):
            expanded.add(item)
            continue
        text = str(item).strip()
        if text == "all":
            expanded.add(text)
        elif text.isdecimal():
            expanded.add(int(text))
        else:
            first, sep, last = text.partition("-")
            first, last = first.strip(), last.strip()
            # Reject anything but exactly "<digits>-<digits>" instead of
            # silently building a stepped or shifted range.
            if not (sep and first.isdecimal() and last.isdecimal()):
                raise ValueError("invalid id range: %r" % (item,))
            expanded.update(range(int(first), int(last) + 1))
    return expanded

def global_threads_filter_switch(gids, inferiors, frames):
    """
    Generator over all global threads specified by the filter. Also switch to the thread and frame.

    Each yielded thread has first been made the selected thread. When
    ``frames`` has an entry keyed by the thread's global number (as a
    string), the selected frame is then moved to that level. If the stack
    ends first, the last reachable frame stays selected. Threads without
    an entry keep whatever frame the switch left selected.
    """
    for thread in global_threads_filter(gids, inferiors):
        thread.switch()
        key = str(thread.global_num)
        if key in frames:
            # JSON may deliver the level as a string; normalise before comparing.
            _select_frame_level(int(frames[key]))
        yield thread


def _select_frame_level(level):
    """Step the selected frame older/newer, one frame at a time, toward ``level``."""
    start = gdb.selected_frame()
    frame = start
    while frame.level() != level:
        step = frame.older() if frame.level() < level else frame.newer()
        if step is None:
            break
        frame = step
    if frame is not start:
        frame.select()

def global_threads_filter(gids, inferiors):
    """
    Generator over all global threads specified by the filter.

    Threads are yielded in the order produced by global_threads(). A thread
    is kept when its global number is in ``gids`` and its inferior number is
    in ``inferiors``. The marker "all" in either collection disables that
    criterion.
    """
    any_gid = "all" in gids
    any_inferior = "all" in inferiors
    for candidate in global_threads():
        gid_ok = any_gid or candidate.global_num in gids
        if not gid_ok:
            continue
        if any_inferior or candidate.inferior.num in inferiors:
            yield candidate

def global_threads():
    """
    Generator over every thread of every inferior known to GDB.

    Threads are yielded inferior by inferior, in the order returned by
    gdb.inferiors() and Inferior.threads().
    """
    yield from itertools.chain.from_iterable(
        inferior.threads() for inferior in gdb.inferiors())

# Instantiate the command so the gdb.MICommand constructor registers
# "-forge-threads-cross-expressions" with GDB when this script is loaded.
ForgeThreadsCrossExpressions()
