# encoding=utf-8
"""
Class used to encapsulate some calculations in parselmouth
"""

__author__ = "Aaron Randreth"
__copyright__ = "Copyright 2015+, Consortium MonPaGe"
__license__ = "Creative Commons 4.0 By-Nc-Sa"
__maintainer__ = "Roland Trouville"
__email__ = "contact.monpage@gmail.com"
__status__ = "Production"

from typing import Optional

import numpy
import parselmouth  # type: ignore

# Unit in which every pitch query below is expressed.
PITCH_UNIT: str = "Hertz"
# Interpolation used when searching pitch extrema ("Get minimum" and "Get maximum").
MINIMUM_INTERPOLATION_TYPE: str = "Parabolic"
# Default end time meaning "analyse the whole file": Praat queries treat an end time that is
# not greater than the start time as the entire time domain.
PRAAT_FULL_FILE: float = 0.0


# In speech analysis, "pitch" is used here as a synonym for F0, the fundamental frequency.
class ParselmouthPitch:
	"""
	A wrapper class for analyzing pitch-related acoustic features from a Parselmouth Sound and Pitch object
	within a specified time interval.

	Following Praat's convention, an interval whose end time is not greater than its start time
	(e.g. the default ``end_time=PRAAT_FULL_FILE``) stands for the whole time domain of the analysed object.

	Attributes:
	    pitch (parselmouth.Pitch): The Parselmouth Pitch object.
	    start_time (float): Start time (in seconds) of the analysis interval.
	    end_time (float): End time (in seconds) of the analysis interval.
	    sound (parselmouth.Sound): The Parselmouth Sound object.
	    point_process (parselmouth.PointProcess): Point process derived from the sound and pitch, used for jitter/shimmer calculations.
	    HIGH_VOICE_BREAKS_FLOOR (int): Threshold (in percent) above which voice breaks are considered high.
	    UNUSABLE_VOICE_BREAKS_FLOOR (int): Threshold (in percent) above which voice breaks are considered unusable.
	"""

	pitch: parselmouth.Pitch
	start_time: float
	end_time: float

	sound: parselmouth.Sound
	point_process: "parselmouth.PointProcess"

	HIGH_VOICE_BREAKS_FLOOR = 15
	UNUSABLE_VOICE_BREAKS_FLOOR = 25

	def __init__(
		self,
		sound,
		pitch: parselmouth.Pitch,
		start_time: float = 0,
		end_time: float = PRAAT_FULL_FILE,
	):
		"""
		Args:
			sound (parselmouth.Sound): Sound being analysed.
			pitch (parselmouth.Pitch): Pitch contour (presumably computed from ``sound``).
			start_time (float): Start time (in seconds) of the analysis interval.
			end_time (float): End time (in seconds) of the analysis interval. When not greater
				than ``start_time`` (the default), the whole time domain is analysed.
		"""
		self.sound = sound
		self.pitch = pitch

		# Glottal pulses obtained by cross-correlation ("cc") of the sound guided by the pitch
		# contour; shared by the jitter, shimmer and voice-break computations.
		self.point_process = parselmouth.praat.call(
			[sound, pitch], "To PointProcess (cc)"
		)

		self.start_time = start_time
		self.end_time = end_time

	def _resolve_interval(self, data) -> tuple[float, float]:
		"""
		Return the concrete (start, end) analysis interval.

		Praat query commands treat an interval whose end is not greater than its start as the
		whole time domain; this applies the same rule for code that uses the times directly.

		Args:
			data: Parselmouth object (Sound, Pitch, ...) whose ``xmin``/``xmax`` span the full domain.

		Returns:
			tuple[float, float]: Start and end times in seconds.
		"""
		if self.end_time <= self.start_time:
			return data.xmin, data.xmax
		return self.start_time, self.end_time

	def get_plot_values(self) -> tuple[list[float], list[float]]:
		"""
		Get filtered pitch time and frequency values for plotting.

		Only includes points within the analysis interval and frequency range (0, 1000) Hz.
		Points outside the frequency range (e.g. unvoiced frames, reported as 0 Hz) are kept
		as NaN in BOTH lists, so the two lists always have the same length.

		Returns:
			tuple[list[float], list[float]]: Filtered times and frequencies.
		"""
		start, end = self._resolve_interval(self.pitch)

		times = self.pitch.xs()
		frequencies = self.pitch.selected_array["frequency"]

		filtered_times = []
		filtered_frequencies = []

		floor, ceil = 0, 1000

		for t, frequency in zip(times, frequencies):
			if not start <= t <= end:
				continue

			if floor < frequency < ceil:
				filtered_times.append(t)
				filtered_frequencies.append(frequency)
			else:
				# NaN gaps presumably let the plot break the line instead of joining across them.
				filtered_times.append(float("NaN"))
				filtered_frequencies.append(float("NaN"))

		return filtered_times, filtered_frequencies

	def get_minimum(self) -> float:
		"""
		Get the minimum pitch value within the time range using specified unit and interpolation.

		Returns:
			float: Minimum pitch value.
		"""
		return parselmouth.praat.call(
			self.pitch,
			"Get minimum",
			self.start_time,
			self.end_time,
			PITCH_UNIT,
			MINIMUM_INTERPOLATION_TYPE,
		)

	def get_maximum(self) -> float:
		"""
		Get the maximum pitch value within the time range using specified unit and interpolation.

		Returns:
			float: Maximum pitch value.
		"""
		# The same interpolation setting is deliberately shared with get_minimum.
		return parselmouth.praat.call(
			self.pitch,
			"Get maximum",
			self.start_time,
			self.end_time,
			PITCH_UNIT,
			MINIMUM_INTERPOLATION_TYPE,
		)

	def get_quantile(
		self,
		quantile: float,
		start_time: Optional[float] = None,
		end_time: Optional[float] = None,
	) -> float:
		"""
		Get the pitch quantile value within a specified time interval.

		Args:
			quantile (float): Quantile value between 0 and 1 (e.g., 0.05 for 5th percentile).
			start_time (Optional[float]): Start time of interval. Defaults to object's start_time.
			end_time (Optional[float]): End time of interval. Defaults to object's end_time.

		Returns:
			float: Pitch value at the specified quantile.
		"""
		if start_time is None:
			start_time = self.start_time

		if end_time is None:
			end_time = self.end_time

		return parselmouth.praat.call(
			self.pitch,
			"Get quantile",
			start_time,
			end_time,
			quantile,
			PITCH_UNIT,
		)

	def get_mean(self) -> float:
		"""
		Calculate mean pitch value in the selected time interval.

		Returns:
			float: Mean pitch.
		"""
		return parselmouth.praat.call(
			self.pitch, "Get mean", self.start_time, self.end_time, PITCH_UNIT
		)

	def get_standard_deviation(self) -> float:
		"""
		Calculate standard deviation of pitch in the selected time interval.

		Returns:
			float: Standard deviation of pitch.
		"""
		return parselmouth.praat.call(
			self.pitch,
			"Get standard deviation",
			self.start_time,
			self.end_time,
			PITCH_UNIT,
		)

	def get_covariation(self) -> float:
		"""
		Compute coefficient of variation (std dev / mean) of pitch.

		Returns:
			float: Coefficient of variation.
		"""
		return self.get_standard_deviation() / self.get_mean()

	def __get_semitone(
		self, start_time: float = 0, end_time: float = PRAAT_FULL_FILE
	) -> float:
		"""
		Compute semitone range between 5th and 95th pitch quantiles over a time interval.

		Args:
			start_time (float): Start time of the interval.
			end_time (float): End time of the interval.

		Returns:
			float: Semitone difference between 95th and 5th pitch quantiles.
		"""
		minf0 = self.get_quantile(0.05, start_time, end_time)
		maxf0 = self.get_quantile(0.95, start_time, end_time)

		# One octave (frequency ratio 2) is 12 semitones.
		semitone = 12 * numpy.log2(maxf0 / minf0)

		return semitone

	def get_semitone_range(self) -> float:
		"""
		Compute semitone range difference between first and second halves of the interval.

		Returns:
			float: Difference in semitone range between two halves of the interval
			(second half minus first half).
		"""
		# Resolve the full-file sentinel first: halving (0, 0) would make both halves cover
		# the whole file and the result would always be 0.
		start, end = self._resolve_interval(self.pitch)
		middle_time = (start + end) / 2

		semitone_R1 = self.__get_semitone(start, middle_time)
		semitone_R2 = self.__get_semitone(middle_time, end)

		return semitone_R2 - semitone_R1

	def get_jitter(
		self,
		period_floor: float = 0.0001,
		period_ceiling: float = 0.02,
		maximum_period_factor: float = 1.3,
	) -> float:
		"""
		Calculate jitter (ppq5) percentage over the selected interval.

		Args:
			period_floor (float): Minimum period in seconds.
			period_ceiling (float): Maximum period in seconds.
			maximum_period_factor (float): Maximum allowed factor between consecutive periods.

		Returns:
			float: Jitter as percentage.
		"""
		ppq5_jitter = parselmouth.praat.call(
			self.point_process,
			"Get jitter (ppq5)",
			self.start_time,
			self.end_time,
			period_floor,
			period_ceiling,
			maximum_period_factor,
		)

		# Praat returns a ratio; convert to percent.
		return ppq5_jitter * 100

	def get_shimmer(
		self,
		period_floor: float = 0.0001,
		period_ceiling: float = 0.02,
		maximum_period_factor: float = 1.3,
		maximum_amplitude_factor: float = 1.6,
	) -> float:
		"""
		Calculate shimmer (apq11) percentage over the selected interval.

		Args:
			period_floor (float): Minimum period in seconds.
			period_ceiling (float): Maximum period in seconds.
			maximum_period_factor (float): Maximum allowed factor between consecutive periods.
			maximum_amplitude_factor (float): Maximum allowed amplitude factor.

		Returns:
			float: Shimmer as percentage.
		"""
		apq11_shimmer = parselmouth.praat.call(
			[self.sound, self.point_process],
			"Get shimmer (apq11)",
			self.start_time,
			self.end_time,
			period_floor,
			period_ceiling,
			maximum_period_factor,
			maximum_amplitude_factor,
		)

		# Praat returns a ratio; convert to percent.
		return apq11_shimmer * 100

	def get_smoothed_cepstral_peak_prominence(self) -> float:
		"""
		Calculate the smoothed cepstral peak prominence (CPPS) for the selected interval.

		Returns:
			float: CPPS value.
		"""
		# Extracting (0, 0) would yield an (almost) empty part, so resolve the full-file sentinel
		# against the sound's own time domain first.
		start, end = self._resolve_interval(self.sound)
		part = self.sound.extract_part(
			start,
			end,
			parselmouth.WindowShape.RECTANGULAR,
			1,
			False,
		)
		# NOTE(review): per Praat's "To PowerCepstrogram" signature these are presumably
		# pitch floor (Hz), time step (s), maximum frequency (Hz), pre-emphasis from (Hz).
		pp = parselmouth.praat.call(part, "To PowerCepstrogram", 60, 0.002, 5000, 50)

		# NOTE(review): argument order follows Praat's "Get CPPS" command (tilt subtraction,
		# time/quefrency averaging windows, peak-search pitch range 60-330 Hz, tolerance,
		# interpolation, tilt-line quefrency range, line type, fit method) — verify against
		# the installed Praat version.
		cpps = parselmouth.praat.call(
			pp,
			"Get CPPS",
			"yes",
			0.02,
			0.0005,
			60,
			330,
			0.05,
			"Parabolic",
			0.001,
			0,
			"Exponential decay",
			"Robust",
		)

		return cpps

	def get_degree_voicing(self) -> float:
		"""
		Calculate the degree of voice breaks.

		This is the share of the sound's duration covered by gaps between consecutive glottal
		pulses longer than the maximum voiced period. Despite the method name, a HIGHER value means
		LESS voicing; this is the value compared against HIGH_VOICE_BREAKS_FLOOR and
		UNUSABLE_VOICE_BREAKS_FLOOR.

		Returns:
			float: Degree of voice breaks as a percentage of the sound's duration.
		"""
		# Same value as the default period_ceiling ("longest period") of get_jitter/get_shimmer.
		max_voiced_period = 0.02

		n_pulses = parselmouth.praat.call(self.point_process, "Get number of points")
		pulse_times = numpy.asarray(
			[
				parselmouth.praat.call(self.point_process, "Get time from index", i)
				for i in range(1, n_pulses + 1)  # Praat point indices are 1-based
			],
			dtype=float,
		)

		# NOTE(review): this uses every pulse and the whole sound duration, ignoring
		# start_time/end_time unlike the other measures — confirm this is intended.
		periods = numpy.diff(pulse_times)
		break_duration = periods[periods > max_voiced_period].sum()

		return break_duration / self.sound.duration * 100

	def get_is_high_voice_breaks(self) -> bool:
		"""
		Check if degree of voicing indicates high voice breaks.

		Returns:
			bool: True if degree of voicing is above HIGH_VOICE_BREAKS_FLOOR threshold.
		"""
		return self.get_degree_voicing() > self.HIGH_VOICE_BREAKS_FLOOR

	def get_is_unusable_voice_breaks(self) -> bool:
		"""
		Check if degree of voicing indicates unusable voice breaks.

		Returns:
			bool: True if degree of voicing is above UNUSABLE_VOICE_BREAKS_FLOOR threshold.
		"""
		return self.get_degree_voicing() > self.UNUSABLE_VOICE_BREAKS_FLOOR
