from typing import Any, Generic, Optional, TypeVar, cast

try:
	from typing import override
except ImportError:
	from typing_extensions import override  # noqa: F401

from abc import ABC, abstractmethod
from enum import Enum

import pyqtgraph as pg  # type: ignore
from bidict import bidict
from PySide6 import QtGui
from PySide6.QtCore import Qt, Signal
from PySide6.QtWidgets import QVBoxLayout, QWidget

from .markers import IntervalMarker, IntervalMarkerList, Marker, MarkerList

# Type parameters of TextgridConverter: T is the object being converted
# (e.g. a Tier or a TextGrid) and U is its TextGrid-side representation.
T = TypeVar("T")
U = TypeVar("U")


class TextgridConverter(ABC, Generic[T, U]):
	"""
	Two-way converter between an object of type T and its TextGrid form of type U.

	Concrete converters must implement both directions.
	"""

	@abstractmethod
	def to_textgrid(self, to_convert: T) -> U:
		"""
		Build the TextGrid representation of an object.

		Args:
			to_convert (T): Object to export.

		Returns:
			U: TextGrid representation of ``to_convert``.
		"""
		...

	@abstractmethod
	def from_textgrid(self, textgrid: U) -> T:
		"""
		Rebuild an object from its TextGrid representation.

		Args:
			textgrid (U): TextGrid representation to import.

		Returns:
			T: Object reconstructed from ``textgrid``.
		"""
		...


class TierType(Enum):
	"""
	Kind of a tier.

	INTERVAL_TIER: tier made of contiguous intervals.
	POINT_TIER: tier made of independent point markers.
	"""

	# Plain int; a stray trailing comma previously made this the tuple (0,).
	INTERVAL_TIER = 0
	POINT_TIER = 1


# Pen shared by the marker lines of every tier: blue ("b"), width 2.
THEME_PEN = pg.mkPen("b", width=2)

# Element type stored by a tier: Marker for point tiers, IntervalMarker for interval tiers.
TierElement = TypeVar("TierElement", Marker, IntervalMarker)
# Element type accepted by Tier.remove_element (both subclasses in this module use Marker).
RemoveTierElement = TypeVar("RemoveTierElement", Marker, IntervalMarker)


class Tier(pg.PlotWidget, Generic[TierElement, RemoveTierElement]):
	"""
	Base plot widget for a single tier of a TextGrid.

	A tier displays its elements over the time range [start_time, end_time] on the
	x axis (the y axis is fixed to [0, 1]) and delegates conversion to the TextGrid
	format to a TextgridConverter.

	Note:
		The metaclass of this class comes from pg.PlotWidget (a Qt class), not from
		ABCMeta, so ``@abstractmethod`` is not enforced at instantiation time. The
		abstract methods therefore raise NotImplementedError, so a missing override
		fails loudly instead of silently returning None.
	"""

	name: str
	__start_time: float
	__end_time: float
	__tier_type: TierType
	__converter: TextgridConverter["Tier", Any]

	# Emitted with (previous_position, new_position) when an element is moved.
	ELEMENT_POSITION_CHANGED = Signal(float, float)

	def __init__(
		self,
		name: str,
		tier_type: TierType,
		start_time: float,
		end_time: float,
		converter: TextgridConverter["Tier", Any],
	):
		"""
		Args:
			name (str): Display name of the tier.
			tier_type (TierType): Kind of tier (point or interval).
			start_time (float): Left limit of the displayed time range.
			end_time (float): Right limit of the displayed time range.
			converter (TextgridConverter["Tier", Any]): Converter used by to_textgrid().
		"""
		super().__init__()

		self.name = name
		# Stored so callers can query the tier kind through get_tier_type().
		self.__tier_type = tier_type
		self.__start_time = start_time
		self.__end_time = end_time

		self.__converter = converter

		# Hide the left axis values and ticks while reserving a fixed text width.
		self.getAxis("left").setStyle(showValues=False, tickAlpha=0, tickTextWidth=60)

		# Only horizontal navigation: the y axis stays fixed to [0, 1].
		self.setMouseEnabled(y=False)
		self.setYRange(0, 1)
		self.setFixedHeight(100)

		self.setXRange(self.__start_time, self.__end_time)
		self.setLabel("bottom", "Temps", units="s")

	def __eq__(self, other: object) -> bool:
		"""
		Check if two Tier objects are equal.

		Args:
			other (object): The object to compare with.

		Returns:
			bool: True if both objects are the same instance, False otherwise.
		"""
		if not isinstance(other, Tier):
			return NotImplemented
		return self is other

	def __hash__(self) -> int:
		"""
		Identity-based hash, consistent with the identity-based __eq__.

		Defining __eq__ alone would make Python set __hash__ to None and leave
		tiers unusable as dict keys or set members.
		"""
		return id(self)

	def __repr__(self) -> str:
		"""
		Return a string representation of the Tier object, showing its name and time limits.

		Returns:
			str: String describing the tier.
		"""
		return f"'name : {self.name}, limits: {self.__start_time} - {self.__end_time}'"

	@abstractmethod
	def add_element(self, element: TierElement, **kargs: Any):
		"""
		Add an element to the tier.

		Args:
			element (TierElement): The element to add.
			**kargs (Any): Additional arguments.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	@abstractmethod
	def remove_element(self, element: RemoveTierElement):
		"""
		Remove a specified element from the tier.

		Args:
			element (RemoveTierElement): The element to remove.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	@abstractmethod
	def remove_element_by_idx(self, index: int):
		"""
		Remove an element from the tier by its index.

		Args:
			index (int): The index of the element to remove.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	@abstractmethod
	def get_element(self, index: int) -> TierElement:
		"""
		Retrieve an element from the tier by its index.

		Args:
			index (int): The index of the element to retrieve.

		Returns:
			TierElement: The element at the given index.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	@abstractmethod
	def get_elements(self) -> list[TierElement]:
		"""
		Return all elements contained in the tier.

		Returns:
			list[TierElement]: List of elements within the tier.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	@abstractmethod
	def change_element_position(self, marker: Marker, new_value: float):
		"""
		Change the position of a specified element (marker) within the tier.

		Args:
			marker (Marker): The element whose position will be changed.
			new_value (float): The new position value.

		Raises:
			NotImplementedError: If not implemented in subclass.
		"""
		raise NotImplementedError

	def to_textgrid(self) -> Any:
		"""
		Convert the tier to a TextGrid-compatible format using the associated converter.

		Returns:
			Any: TextGrid representation of the tier.
		"""
		return self.__converter.to_textgrid(self)

	def get_name(self) -> str:
		"""
		Get the display name of the tier.

		Returns:
			str: The tier name.
		"""
		return self.name

	def get_tier_type(self) -> TierType:
		"""
		Get the kind of the tier.

		Returns:
			TierType: The tier type given at construction.
		"""
		return self.__tier_type

	def get_start_time(self) -> float:
		"""
		Get the start time of the tier.

		Returns:
			float: Start time value.
		"""
		return self.__start_time

	def get_end_time(self) -> float:
		"""
		Get the end time of the tier.

		Returns:
			float: End time value.
		"""
		return self.__end_time


class PointTier(Tier[Marker, Marker]):
	"""
	Tier made of independent point markers, each drawn as a movable labelled line.

	Attributes:
		mlist (MarkerList): Model holding the tier's markers.
		marker_to_display (bidict[pg.InfiniteLine, Marker]): Two-way mapping between
			each drawn line and the marker it represents.
		hovered_line (Optional[pg.InfiniteLine]): Line of this tier currently under the
			mouse cursor; target of keyboard label editing.
	"""

	mlist: MarkerList
	marker_to_display: bidict[pg.InfiniteLine, Marker]
	hovered_line: Optional[pg.InfiniteLine]

	def __init__(
		self,
		name: str,
		start_time: float,
		end_time: float,
		converter: TextgridConverter["PointTier", Any],
	):
		"""
		Args:
			name (str): Display name of the tier.
			start_time (float): Left limit of the displayed time range.
			end_time (float): Right limit of the displayed time range.
			converter (TextgridConverter["PointTier", Any]): Converter used by to_textgrid().
		"""
		super().__init__(name, TierType.POINT_TIER, start_time, end_time, converter)

		self.mlist = MarkerList()
		self.marker_to_display = bidict()
		self.hovered_line = None

		self.scene().sigMouseHover.connect(self.mouse_moved)

	@staticmethod
	def _escape_label_format(text: str) -> str:
		"""
		Escape ``text`` for use as an InfiniteLine label format string.

		pyqtgraph renders the label with ``format.format(value=...)``, so literal
		braces must be doubled to be displayed as-is instead of raising.
		"""
		return text.replace("{", "{{").replace("}", "}}")

	def mouse_moved(self, hover_items: list[pg.PlotItem]):
		"""
		Remember which of this tier's lines (if any) is under the mouse cursor.

		Args:
			hover_items (list[pg.PlotItem]): Items currently hovered, as emitted by
				the scene's sigMouseHover.

		Sets:
			self.hovered_line (pg.InfiniteLine | None): The first hovered line that
				belongs to this tier, or None.
		"""
		self.hovered_line = next(
			(
				item
				for item in hover_items
				if isinstance(item, pg.InfiniteLine) and item in self.marker_to_display
			),
			None,
		)

	@override
	def add_element(self, element: Marker, **kargs: Any):
		"""
		Add a new marker element to the list and its visual representation to the plot.

		Args:
			element (Marker): The marker to be added.
			**kargs (Any): Additional keyword arguments (unused).

		Behavior:
			- Ignores markers already present in the list.
			- Creates a movable InfiniteLine labelled with the marker's name.
			- Updates the marker position when the user finishes dragging the line.
		"""
		if element in self.mlist:
			return

		element = self.mlist.add_marker(element)
		element_line = pg.InfiniteLine(
			pos=element.position,
			# The label text is a format string: escape braces in the name.
			label=self._escape_label_format(element.name),
			labelOpts={"color": (0, 0, 0)},
			pen=THEME_PEN,
			movable=True,
		)

		self.addItem(element_line)
		self.marker_to_display[element_line] = element

		element_line.sigPositionChangeFinished.connect(
			lambda line: self.change_element_position(
				self.marker_to_display[line], line.value()
			)
		)

	@override
	def remove_element_by_idx(self, index: int):
		"""
		Remove a marker element by its index from the list and update the plot.

		Args:
			index (int): The index of the marker to remove.
		"""
		removed_marker = self.mlist.remove_marker_by_idx(index)
		self.remove_element(removed_marker)

	@override
	def remove_element(self, element: Marker):
		"""
		Remove a marker's line from the plot and forget the line <-> marker mapping.

		The marker itself is not removed from ``mlist`` here; remove_element_by_idx
		removes it from the list before calling this method.

		Args:
			element (Marker): The marker to remove.

		Raises:
			KeyError: If the marker has no line in this tier.
		"""
		# Pop (not just look up) so the mapping does not keep the removed line
		# alive and the same marker can be displayed again later.
		marker_line = self.marker_to_display.inverse.pop(element)
		if self.hovered_line is marker_line:
			self.hovered_line = None
		self.removeItem(marker_line)

	@override
	def get_element(self, index: int) -> Marker:
		"""
		Retrieve a marker element by index.

		Args:
			index (int): Index of the marker to retrieve.

		Returns:
			Marker: The marker at the specified index.
		"""
		return self.mlist.get_marker(index)

	@override
	def get_elements(self) -> list[Marker]:
		"""
		Return all marker elements in the list.

		Returns:
			list[Marker]: List of markers currently managed.
		"""
		return self.mlist.get_markers()

	@override
	def change_element_position(self, marker: Marker, new_value: float):
		"""
		Change the position of a given marker and emit an event signaling the change.

		Args:
			marker (Marker): The marker whose position is to be updated.
			new_value (float): The new position value.

		Behavior:
			- Updates the marker's position.
			- Notifies the marker list about the change.
			- Emits ELEMENT_POSITION_CHANGED with the previous and new positions.
		"""
		previous_value = marker.position
		marker.position = new_value
		self.mlist.notify_marker_changed()
		self.ELEMENT_POSITION_CHANGED.emit(previous_value, new_value)

	@override
	def keyPressEvent(self, event: QtGui.QKeyEvent):
		"""
		Edit the label (and marker name) of the currently hovered line.

		Args:
			event (QtGui.QKeyEvent): The key press event.

		Behavior:
			- Always forwards the event to the superclass first.
			- Does nothing if no line of this tier is hovered or the key is unknown.
			- Backspace removes the last character of the label.
			- Printable text is appended; other keys (modifiers, arrows, Enter,
			  Escape...) leave the label unchanged.
			- The marker's name is updated to the new label text.
		"""
		super().keyPressEvent(event)

		line = self.hovered_line
		if line is None or event.key() == Qt.Key.Key_unknown:
			return

		marker = self.marker_to_display.get(line)
		if marker is None:
			# The hovered line is no longer part of this tier.
			return

		old_text = line.label.toPlainText()

		if event.key() == Qt.Key.Key_Backspace:
			new_text = old_text[:-1]
		else:
			typed = event.text()
			if not typed or not typed.isprintable():
				return
			new_text = old_text + typed

		# setFormat() takes a format string, so braces must be escaped.
		line.label.setFormat(self._escape_label_format(new_text))
		marker.name = new_text


class IntervalTier(Tier[IntervalMarker, Marker]):
	"""
	Tier made of contiguous intervals delimited by boundary markers.

	Every boundary marker is drawn as a vertical InfiniteLine. The text of the
	interval that starts at a boundary is drawn as a TextItem kept centred between
	that boundary's line and the next boundary's line.

	Attributes:
		mlist (IntervalMarkerList): Model holding the boundaries and intervals.
		marker_to_display (bidict[Marker, pg.InfiniteLine]): Boundary <-> line mapping.
		marker_label (dict[Marker, pg.TextItem]): Label of the interval starting at
			each boundary; the last boundary has no label.
		last_mouse_position: Last scene position reported by sigMouseMoved, or None.
	"""

	mlist: IntervalMarkerList
	marker_to_display: bidict[Marker, pg.InfiniteLine]
	marker_label: dict[Marker, pg.TextItem]
	last_mouse_position = None

	# Smallest distance a dragged boundary keeps from its neighbours.
	MIN_INTERVAL_DURATION: float = 0.005

	def __init__(
		self,
		name: str,
		start_time: float,
		end_time: float,
		converter: TextgridConverter["IntervalTier", Any],
	):
		"""
		Create the tier with one non-movable interval covering [start_time, end_time].

		Args:
			name (str): Display name of the tier.
			start_time (float): Start of the tier and of its initial interval.
			end_time (float): End of the tier and of its initial interval.
			converter (TextgridConverter["IntervalTier", Any]): Converter used by to_textgrid().
		"""
		super().__init__(name, TierType.INTERVAL_TIER, start_time, end_time, converter)

		self.mlist = IntervalMarkerList()
		self.marker_to_display = bidict()
		self.marker_label = {}
		self.last_mouse_position = None
		# (line, slot) pairs connected to keep each label centred, so they can be
		# disconnected before the label is re-anchored or removed.
		self.__label_slots: dict[Marker, list[tuple[pg.InfiniteLine, Any]]] = {}

		self.add_element(
			IntervalMarker.new_interval(start_time, end_time), movable=False
		)

		self.scene().sigMouseMoved.connect(self.mouse_moved)

	def mouse_moved(self, position):
		"""
		Updates the stored last mouse position.

		Args:
			position: The current scene position of the mouse cursor.
		"""
		self.last_mouse_position = position

	@override
	def keyPressEvent(self, event: QtGui.QKeyEvent):
		"""
		Edit the label of the interval under the mouse cursor.

		- Ignores the event if the mouse position is unknown, outside the plot area,
		  or the key is unknown.
		- The edited interval is the one starting at the right-most boundary that
		  lies at or left of the cursor; nothing happens if there is none or if it
		  is the last boundary (which has no label).
		- Backspace removes the last character; printable text is appended; other
		  keys leave the label unchanged.
		- The boundary marker's name is updated to the new label text.

		Args:
			event (QtGui.QKeyEvent): The key press event.
		"""
		super().keyPressEvent(event)

		if self.last_mouse_position is None:
			return

		view_box = self.plotItem.vb
		if not view_box.sceneBoundingRect().contains(self.last_mouse_position):
			return

		if event.key() == Qt.Key.Key_unknown:
			return

		mouse_x = view_box.mapSceneToView(self.last_mouse_position).x()

		interval_start = max(
			(m for m in self.mlist.get_markers() if m.position <= mouse_x),
			key=lambda m: m.position,
			default=None,
		)
		if interval_start is None:
			# Cursor is left of the first boundary: no interval to edit.
			return

		text_label = self.marker_label.get(interval_start)
		if text_label is None:
			# Cursor is at or after the last boundary, which has no label.
			return

		old_text = text_label.toPlainText()

		if event.key() == Qt.Key.Key_Backspace:
			new_text = old_text[:-1]
		else:
			typed = event.text()
			if not typed or not typed.isprintable():
				return
			new_text = old_text + typed

		text_label.setPlainText(new_text)
		interval_start.name = new_text

	def __create_line(self, marker: Marker, movable: bool = True) -> pg.InfiniteLine:
		"""
		Return the line displaying ``marker``, creating it if needed.

		- Returns the marker's own line if it is already displayed.
		- Otherwise returns an existing line at the same position, if any.
		- Otherwise creates an InfiniteLine at the marker's position whose moves
		  are forwarded to change_element_position().

		Args:
			marker (Marker): The marker to create a line for.
			movable (bool): Whether the line should be movable by the user. Defaults to True.

		Returns:
			pg.InfiniteLine: The created or retrieved line.
		"""
		existing_line = self.marker_to_display.get(marker)
		if existing_line is not None:
			return existing_line

		# Collect matching *lines*, not compare_position() results.
		# NOTE(review): assumes compare_position(x) is truthy when the marker lies
		# at x; confirm against Marker. A reused line stays mapped to its original
		# marker only.
		shared_line = next(
			(
				line
				for line in self.marker_to_display.inverse
				if marker.compare_position(line.value())
			),
			None,
		)
		if shared_line is not None:
			return shared_line

		element_line = pg.InfiniteLine(
			pos=marker.position, pen=THEME_PEN, movable=movable
		)

		self.addItem(element_line)
		self.marker_to_display[marker] = element_line

		element_line.sigPositionChanged.connect(
			lambda line: self.change_element_position(
				self.marker_to_display.inverse[line], line.value()
			)
		)

		return element_line

	def __create_label(self, marker: Marker) -> Optional[pg.TextItem]:
		"""
		Create, or refresh, the label of the interval starting at ``marker``.

		- If a label already exists for the marker, updates its text and returns it.
		- The last boundary starts no interval, so it gets no label.
		- A new label is a black 14 pt Arial TextItem kept centred between the
		  marker's line and the next marker's line.

		Args:
			marker (Marker): The boundary marker to label.

		Returns:
			Optional[pg.TextItem]: The created or updated label, or None if the
			marker is the last boundary.
		"""
		existing_label = self.marker_label.get(marker)
		if existing_label is not None:
			existing_label.setPlainText(marker.name)
			return existing_label

		marker_idx = self.mlist.get_marker_idx(marker)

		if marker_idx >= len(self.mlist.get_markers()) - 1:
			return None

		text_item = pg.TextItem(text=marker.name, color=(0, 0, 0), anchor=(0.5, 1))
		text_item.setFont(pg.QtGui.QFont("Arial", 14))

		self.addItem(text_item)
		self.marker_label[marker] = text_item

		self.__config_text_pos_change(marker)

		return text_item

	def __disconnect_label_slots(self, marker: Marker):
		"""
		Disconnect the slots that keep ``marker``'s label centred, if any.

		Args:
			marker (Marker): The boundary marker whose label slots to drop.
		"""
		for line, slot in self.__label_slots.pop(marker, []):
			try:
				line.sigPositionChanged.disconnect(slot)
			except (RuntimeError, TypeError):
				# Best effort: the connection may already be gone.
				pass

	def __config_text_pos_change(self, marker: Marker):
		"""
		Keep the label of ``marker`` centred between its line and the next line.

		Any previous anchoring of this label is disconnected first, so calling this
		repeatedly (e.g. after the neighbour changed) does not accumulate slots.

		Args:
			marker (Marker): The marker whose label position needs to be configured.
		"""
		line = self.marker_to_display[marker]
		marker_idx = self.mlist.get_marker_idx(marker)

		text_item = self.marker_label[marker]

		neighbour = self.mlist.get_marker(marker_idx + 1)
		neighbour_line = self.marker_to_display[neighbour]

		self.__disconnect_label_slots(marker)

		def _recenter(_line: object = None) -> None:
			text_item.setPos((line.value() + neighbour_line.value()) / 2, 0.5)

		_recenter()
		line.sigPositionChanged.connect(_recenter)
		neighbour_line.sigPositionChanged.connect(_recenter)
		self.__label_slots[marker] = [(line, _recenter), (neighbour_line, _recenter)]

	def __config_event_listeners(self):
		"""
		Re-anchor every label to its marker's line and its current neighbour.

		Called after the marker list changes so that labels follow the new layout.
		"""
		for marker in self.marker_label:
			self.__config_text_pos_change(marker)

	@override
	def add_element(self, element: IntervalMarker, movable: bool = True, **kargs: Any):
		"""
		Adds an interval element to the view and sets up its visual representation.

		- Registers the interval in the marker list.
		- Creates (or reuses) lines for the start and end markers.
		- Creates labels for both markers (the last boundary gets none).
		- Re-anchors all labels so they stay centred on their intervals.

		Args:
			element (IntervalMarker): The interval to add.
			movable (bool): Whether the lines representing the interval should be movable.
			**kargs: Additional keyword arguments (not used here).
		"""
		self.mlist.add_interval(element)

		self.__create_line(element.start_time, movable)
		self.__create_line(element.end_time, movable)

		self.__create_label(element.start_time)
		self.__create_label(element.end_time)

		self.__config_event_listeners()

	@override
	def remove_element_by_idx(self, index: int):
		"""
		Removes a boundary marker by its index.

		- Removes the marker from the internal marker list.
		- Delegates to `remove_element` to update the display accordingly.

		Args:
			index (int): Index of the marker to remove.
		"""
		removed_marker = self.mlist.remove_marker_by_idx(index)
		self.remove_element(removed_marker)

	@override
	def remove_element(self, element: Marker):
		"""
		Removes a boundary marker from the display.

		- Removes its line from the plot and from the marker <-> line mapping.
		- Removes the label of the interval starting at it, with its slots.

		The marker itself is not removed from ``mlist`` here; remove_element_by_idx
		does that before calling this method. Labels of neighbouring intervals are
		not re-anchored here.

		Args:
			element (Marker): The marker to remove.

		Raises:
			KeyError: If the marker has no line in this tier.
		"""
		marker_line = self.marker_to_display.pop(element)
		self.removeItem(marker_line)

		self.__disconnect_label_slots(element)
		text_item = self.marker_label.pop(element, None)
		if text_item is not None:
			self.removeItem(text_item)

	@override
	def get_element(self, index: int) -> IntervalMarker:
		"""
		Returns the interval at the specified index.

		Index bounds handling is delegated to IntervalMarkerList.get_interval.

		Args:
			index (int): Index of the desired interval.

		Returns:
			IntervalMarker: The interval at the given index.
		"""
		return self.mlist.get_interval(index)

	@override
	def get_elements(self) -> list[IntervalMarker]:
		"""
		Returns all the intervals currently managed by the view.

		Returns:
			list[IntervalMarker]: A list of all interval markers.
		"""
		return self.mlist.get_intervals()

	@override
	def change_element_position(self, marker: Marker, new_value: float):
		"""
		Update the position of a boundary while keeping it between its neighbours.

		If ``new_value`` is not strictly between the previous and next boundaries,
		the line is snapped back MIN_INTERVAL_DURATION inside the violated neighbour.
		If that point would itself be out of range (neighbours closer than
		MIN_INTERVAL_DURATION), the line is snapped to the neighbours' midpoint
		instead, which avoids endless back-and-forth snapping.

		Parameters:
			marker (Marker): The marker whose position is to be changed.
			new_value (float): The desired new position for the marker.

		Emits:
			ELEMENT_POSITION_CHANGED (float, float): Emitted with the old and new marker
			positions when a valid position change occurs.
		"""
		marker_idx = self.mlist.get_marker_idx(marker)
		next_marker = self.mlist.get_marker(marker_idx + 1)
		previous_marker = self.mlist.get_marker(marker_idx - 1)

		lower = previous_marker.position
		upper = next_marker.position

		if not lower < new_value < upper:
			gap = self.MIN_INTERVAL_DURATION
			snapped = upper - gap if new_value >= upper else lower + gap
			if not lower < snapped < upper:
				snapped = (lower + upper) / 2
			# setValue() re-emits sigPositionChanged, so this method runs again
			# with the snapped value, which is then accepted below.
			self.marker_to_display[marker].setValue(snapped)
			return

		previous_value = marker.position
		marker.position = new_value
		self.mlist.notify_marker_changed()
		self.ELEMENT_POSITION_CHANGED.emit(previous_value, new_value)


class TextGrid(QWidget):
	"""
	Widget stacking Tier plots vertically, all x-linked to an external plot.

	Attributes:
		linked_plot (pg.PlotWidget): Plot whose horizontal view the tiers follow.
		tiers (list[Tier[Any, Any]]): Tiers in layout order.
	"""

	linked_plot: pg.PlotWidget
	__internal_vb: pg.PlotWidget
	__converter: TextgridConverter["TextGrid", Any]
	tiers: list[Tier[Any, Any]]

	def __init__(
		self, linked_plot: pg.PlotWidget, converter: TextgridConverter["TextGrid", Any]
	):
		"""
		Args:
			linked_plot (pg.PlotWidget): Plot the tiers are horizontally linked to.
			converter (TextgridConverter["TextGrid", Any]): Converter used by to_textgrid().
		"""
		super().__init__()

		self.tiers = []
		self.linked_plot = linked_plot
		# The linked plot itself is the reference view (viewRange / setXLink).
		self.__internal_vb = linked_plot
		self.__converter = converter

		layout = QVBoxLayout()
		layout.setContentsMargins(0, 0, 0, 0)
		layout.setSpacing(0)
		self.setLayout(layout)

	def __link_views(self):
		"""
		Link every tier's x axis to the reference view and limit panning to its range.

		This keeps all tier views synchronized in horizontal zoom and pan.
		"""
		(x_min, x_max), _ = self.__internal_vb.viewRange()
		for tier in self.tiers:
			tier.setXLink(self.__internal_vb)
			tier.setLimits(xMin=x_min, xMax=x_max)

	def add_tier(self, new_tier: Tier[Any, Any], tier_index: int = -1):
		"""
		Add a new tier to the layout and internal tier list at the specified index.

		Parameters:
			new_tier (Tier[Any, Any]): The tier widget to add.
			tier_index (int, optional): Index at which to insert the tier. A value equal
				to the number of tiers, or any negative value, appends the tier at the
				end. Defaults to -1.

		Raises:
			ValueError: If the layout is missing or tier_index is greater than the
				number of tiers.

		Notes:
			After insertion, all tiers are re-linked to the reference view.
		"""
		layout = cast(Optional[QVBoxLayout], self.layout())

		if layout is None:
			raise ValueError("Textgrid has no layout")

		nb_tiers = layout.count()
		# Inserting at nb_tiers is a valid append (needed to fill an empty grid).
		if tier_index > nb_tiers:
			msg = f"Invalid tier_index {tier_index} for nb tiers: {nb_tiers}."
			raise ValueError(msg)

		# QBoxLayout.insertWidget appends for negative indices; normalise so the
		# tier list gets the same position (list.insert(-1, x) would not append).
		if tier_index < 0:
			tier_index = nb_tiers

		layout.insertWidget(tier_index, new_tier)
		self.tiers.insert(tier_index, new_tier)

		self.__link_views()

	def remove_tier_by_idx(self, tier_index: int):
		"""
		Remove a tier from the layout and internal list by its index.

		Parameters:
			tier_index (int): Index of the tier to remove. Negative indices count from
				the end, as for Python lists.

		Raises:
			ValueError: If the index is out of bounds (including when there are no
				tiers), the layout is missing, or the widget cannot be retrieved.

		Notes:
			The removed widget is scheduled for deletion and the remaining views are
			re-linked.
		"""
		nb_tiers = len(self.tiers)
		if not -nb_tiers <= tier_index < nb_tiers:
			msg = f"Invalid tier index {tier_index} for nb tiers: {len(self.tiers)}."
			raise ValueError(msg)

		tier_index %= nb_tiers

		layout = self.layout()

		if layout is None:
			raise ValueError("Textgrid has no layout")

		item = layout.takeAt(tier_index)

		if item is None:
			raise ValueError("Item not in layout")

		self.tiers.pop(tier_index)

		widget = item.widget()

		if widget is None:
			raise ValueError("Error when fetching item widget")

		widget.deleteLater()

		self.__link_views()

	def get_tiers(self) -> list[Tier[Any, Any]]:
		"""
		Return a copy of the list of all current tiers.

		Returns:
			list[Tier[Any, Any]]: A shallow copy of the internal tier list.
		"""
		return self.tiers.copy()

	def get_tiers_by_name(self, tier_name: str) -> list[Tier[Any, Any]]:
		"""
		Retrieve all tiers that match the given name.

		Parameters:
			tier_name (str): The name to search for among the tiers.

		Returns:
			list[Tier[Any, Any]]: A list of tiers matching the provided name.

		Raises:
			ValueError: If the provided tier_name is empty.
		"""
		if not tier_name:
			raise ValueError("The given tier_name was empty.")

		# Tier exposes its name as the public ``name`` attribute.
		return [t for t in self.tiers if t.name == tier_name]

	def get_tier_by_index(self, tier_index: int) -> Tier[Any, Any]:
		"""
		Retrieve a tier by its index from the internal list.

		Parameters:
			tier_index (int): Index of the tier to retrieve. Supports negative indices.

		Returns:
			Tier[Any, Any]: The tier at the specified index.

		Raises:
			ValueError: If the index is out of bounds.
		"""
		if not -len(self.tiers) <= tier_index < len(self.tiers):
			msg = f"Invalid tier index {tier_index} for nb tiers: {len(self.tiers)}."
			raise ValueError(msg)

		return self.tiers[tier_index]

	def get_tier_index(self, tier: Tier[Any, Any]) -> Optional[int]:
		"""
		Get the index of the specified tier in the internal tier list.

		Parameters:
			tier (Tier[Any, Any]): The tier to look for (matched by identity).

		Returns:
			Optional[int]: The index of the tier if found; otherwise, None.
		"""
		return next((i for i, t in enumerate(self.tiers) if t is tier), None)

	def to_textgrid(self) -> Any:
		"""
		Convert the current internal state into a standard TextGrid format.

		Returns:
			Any: A representation of the current object in TextGrid format,
				as defined by the internal converter.
		"""
		return self.__converter.to_textgrid(self)
