import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Tuple, Union

from cotation.acoustic.struct.analysis.cotation_acoustic_analysis import (
	CotationAcousticAnalysis,
)
from cotation.acoustic.struct.praat.tiers import IntervalMarker
from reporting.indicator_reporting import IndicatorReporting
from tools.general_tools import GeneralTools
from tools.options import Options

# Module-level logger, named after this module, used by the persistence helpers below.
logger = logging.getLogger(__name__)


class StepInputPersistence:
	"""
	Handles persistence operations for acoustic measurement steps.

	This class provides methods to save and retrieve step data and step inputs
	through ``IndicatorReporting.participant_result_db``. The data covers the
	step interval (start/duration), the session/participant/judge it belongs
	to, and each per-step input value. The SQL below reads and writes the
	``step``, ``step_input``, ``step_data`` and ``step_input_data`` tables.
	"""

	@staticmethod
	def get_step_input_codes(step_code: str) -> Dict[str, int]:
		"""
		Retrieves all input codes associated with a specific step.

		Args:
			step_code: The unique code identifying the step

		Returns:
			A dictionary mapping input codes to their ``step_input`` database
			IDs (empty when the step is unknown or declares no inputs)
		"""
		sql = """
			SELECT step_input.code, step_input.id
			FROM step INNER JOIN step_input
			ON step.id = step_input.step_id
			WHERE step.code = ?
		"""

		rows = IndicatorReporting.participant_result_db.execute(
			sql, (step_code,)
		).fetchall()

		return {row["code"]: row["id"] for row in rows}

	@staticmethod
	def _parse_entry_date(raw: Any) -> Union[datetime, None]:
		"""
		Converts a stored ``step_data.entry_date`` value into a naive datetime.

		Args:
			raw: The value read from the database

		Returns:
			The parsed datetime, or None when the value is missing or not in
			``YYYY-MM-DD HH:MM:SS`` form
		"""
		if isinstance(raw, datetime):
			# In case the DB driver is configured to already return datetimes.
			return raw
		try:
			return datetime.strptime(raw, "%Y-%m-%d %H:%M:%S")
		except (TypeError, ValueError):
			# TypeError: NULL/non-text value; ValueError: unexpected format.
			return None

	@staticmethod
	def get_saved_step_data(step: CotationAcousticAnalysis, judge_code: str) -> bool:
		"""
		Retrieves previously saved step data for a given step and judge.

		Looks up the ``step_data`` row matching the step code, speaker, judge
		and session date. When one exists, it sets the step's interval from the
		saved start/duration and sets ``step.entry_date``.

		Args:
			step: The acoustic measurement step object to populate
			judge_code: The code identifying the judge/evaluator

		Returns:
			True if saved data was found and loaded, False otherwise
		"""
		sql = """
			SELECT step_data.id as data_id, step_data.interval_start as start,
			step_data.interval_duration as dura, step_data.entry_date
			FROM step, step_data
			WHERE step.code = ?
			AND step.id = step_data.step_id
			AND step_data.participant = ?
			AND step_data.judge = ?
			AND step_data.session_date = ?
		"""

		params = (step.module_code, step.speaker_code, judge_code, step.session_date)

		step_result = IndicatorReporting.participant_result_db.execute(
			sql, params
		).fetchone()

		if step_result is None:
			return False

		start = step_result["start"]
		step.set_interval(
			IntervalMarker.new_interval(start, start + step_result["dura"])
		)
		# An unparseable or missing date is not fatal: the interval is still loaded.
		step.entry_date = StepInputPersistence._parse_entry_date(
			step_result["entry_date"]
		)

		return True

	@staticmethod
	def get_saved_step_inputs(
		step_code: str, speaker_code: str, session_date: str, judge_code: str
	) -> List[Tuple[str, Any, Any]]:
		"""
		Retrieves all saved input values for a specific step, speaker, session and judge.

		Args:
			step_code: The code identifying the step
			speaker_code: The code identifying the speaker/participant
			session_date: The date of the session
			judge_code: The code identifying the judge/evaluator

		Returns:
			The fetched rows, each holding (input_code, info_value, entry_date).
			``entry_date`` is the raw value stored in ``step_data`` (not parsed).
			The list is empty when nothing was saved.
		"""
		sql = """
			SELECT step_input.code, step_input_data.info_value, step_data.entry_date
			FROM step, step_input, step_data, step_input_data
			WHERE step_input.id = step_input_data.step_input_id
			AND step_input.step_id = step.id
			AND step_data.id = step_input_data.step_data_id
			AND step.code = ?
			AND step_data.participant = ?
			AND step_data.judge = ?
			AND step_data.session_date = ?
		"""

		params = (step_code, speaker_code, judge_code, session_date)

		return IndicatorReporting.participant_result_db.execute(sql, params).fetchall()

	def fill_from_saved_inputs(self, step: CotationAcousticAnalysis, judge_code: str):
		"""
		Populates a step object with previously saved input values.

		First loads the saved interval (see ``get_saved_step_data``). Then, for
		each saved input, it either assigns the attribute named after the input
		code or calls ``set_<input_code>``. Inputs with neither are logged and
		skipped.

		Args:
			step: The acoustic measurement step object to populate
			judge_code: The code identifying the judge/evaluator
		"""
		if not self.get_saved_step_data(step, judge_code):
			return

		saved_inputs = self.get_saved_step_inputs(
			step.module_code,
			step.speaker_code,
			step.session_date,
			judge_code,
		)

		if not saved_inputs:
			return

		for saved_input in saved_inputs:
			input_code, input_value = saved_input[0], saved_input[1]

			if hasattr(step, input_code):
				setattr(step, input_code, input_value)
				continue

			# No attribute of that name: fall back to a "set_<code>" method.
			setter_name = f"set_{input_code}"
			try:
				setter: Callable[[Any], None] = getattr(step, setter_name)
				setter(input_value)
			except AttributeError as e:
				logger.exception(
					"No setter %s on %s: %s", setter_name, step.__class__.__name__, e
				)

	@staticmethod
	def save_step_data(step: CotationAcousticAnalysis, judge_code: str) -> int:
		"""
		Saves or updates the base step data in the database.

		Upserts the ``step_data`` row for (step, session date, participant,
		judge) with the current version and interval. When the
		EXPORT_TEXTGRID option is enabled, it also exports a TextGrid.

		Args:
			step: The acoustic measurement step to save
			judge_code: The code identifying the judge/evaluator

		Returns:
			The database ID of the inserted or updated step data

		Raises:
			ValueError: If ``step.module_code`` matches no row of ``step``, so
				no step data could be written.

		Note:
			``INSERT OR REPLACE`` is deliberately avoided. REPLACE performs a
			DELETE followed by an INSERT, and the DELETE is rejected by the
			foreign-key constraint on rows referencing ``step_data``
			(presumably ``step_input_data.step_data_id``). The upsert below
			updates the existing row in place instead.
		"""
		# The WHERE clause on the SELECT also lets SQLite parse the trailing
		# ON CONFLICT as an upsert rather than as a join constraint.
		# "excluded.<col>" refers to the values the INSERT tried to write.
		sql = """
			INSERT INTO step_data
			(step_id, session_date, participant, judge, version, interval_start, interval_duration)
			SELECT step.id, ?, ?, ?, ?, ?, ?
			FROM step
			WHERE step.code = ?
			ON CONFLICT(step_id, session_date, participant, judge)
			DO UPDATE SET
			version = excluded.version,
			interval_start = excluded.interval_start,
			interval_duration = excluded.interval_duration,
			entry_date = CURRENT_TIMESTAMP
			RETURNING step_data.id as step_data_id
		"""

		version = GeneralTools.get_version()
		start_time = float(step.get_interval().start_time)
		duration = step.get_interval_duration()

		params = (
			step.session_date,
			step.speaker_code,
			judge_code,
			version,
			start_time,
			duration,
			step.module_code,
		)

		if Options.is_enabled(Options.Option.EXPORT_TEXTGRID):
			GeneralTools.create_text_grid(
				"PatternTextGrid.TextGrid",
				judge_code,
				start_time,
				duration,
				step,
			)

		row = IndicatorReporting.participant_result_db.execute(sql, params).fetchone()
		if row is None:
			# INSERT ... SELECT produced no row: the step code is unknown.
			raise ValueError(
				f"Cannot save step data: no step with code {step.module_code!r}"
			)

		# NOTE(review): datetime.now() is local time, while SQLite's
		# CURRENT_TIMESTAMP (written above on update) is UTC. get_saved_step_data
		# reloads the stored value, so the two can differ by the UTC offset.
		# Confirm which one consumers of entry_date expect.
		step.entry_date = datetime.now()

		return row["step_data_id"]

	@staticmethod
	def save_step_input_data(
		input_code: str,
		input_id: int,
		data_id: int,
		step: CotationAcousticAnalysis,
	):
		"""
		Saves a single input value for a step.

		The value is read from the step attribute named ``input_code``, or else
		from a ``get_<input_code>`` method. If neither exists, the error is
		logged and nothing is written for this input.

		Args:
			input_code: The code identifying the input field
			input_id: The database ID of the input field
			data_id: The database ID of the parent step data
			step: The acoustic measurement step containing the input value
		"""
		sql = """
			INSERT INTO step_input_data
			(step_input_id, step_data_id, info_value, version)
			VALUES (?, ?, ?, ?)
			ON CONFLICT(step_input_id, step_data_id) DO
			UPDATE SET version = excluded.version, info_value = excluded.info_value
		"""

		try:
			value: Any = getattr(step, input_code)
		except AttributeError:
			# Try with getter
			try:
				logger.info("Trying to access %s via a getter", input_code)

				getter: Callable[[], Any] = getattr(step, f"get_{input_code}")
				value = getter()
			except AttributeError as e:
				logger.error(
					"No getter like 'get_%s' in %s: %s",
					input_code,
					step.__class__.__name__,
					e,
				)
				# No value to persist: skip this input instead of failing on an
				# unbound value.
				return

		params = (
			input_id,
			data_id,
			value,
			GeneralTools.get_version(),
		)

		IndicatorReporting.participant_result_db.execute(sql, params)

	def save_all_step_input_data(self, step: CotationAcousticAnalysis, judge_code: str):
		"""
		Saves the step data and every input value of a step, then commits once.

		Writes the ``step_data`` row first, to obtain its ID, and then one
		``step_input_data`` row per input code declared for the step. The
		changes are committed only after every write has been issued.

		Args:
			step: The acoustic measurement step to save completely
			judge_code: The code identifying the judge/evaluator
		"""
		data_id: int = self.save_step_data(step, judge_code)

		step_input_codes: Dict[str, int] = self.get_step_input_codes(step.module_code)

		for input_code, input_id in step_input_codes.items():
			self.save_step_input_data(
				input_code,
				input_id,
				data_id,
				step,
			)

		self.commit()

	@staticmethod
	def commit():
		"""
		Commits pending changes on ``IndicatorReporting.participant_result_db``.

		Makes every write issued since the last commit permanent.
		"""
		IndicatorReporting.participant_result_db.commit()
