Update README.md

Signed-off-by: David Rotermund <54365609+davrot@users.noreply.github.com>
This commit is contained in:
David Rotermund 2023-12-01 18:15:37 +01:00 committed by GitHub
parent 4a5e2627e2
commit 805cda7394
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -418,5 +418,146 @@ plt.show()
![figure 7](image7.png)
### Fixing the problems -- Cone of influence masked
Instead of marking the invalid regions in the plot, we want to continue to analyze the data later but without the invalide data. Thus we can mask that part of the tranformations with NaNs.
```python
import numpy as np
import matplotlib.pyplot as plt
import pywt
# Calculate the wavelet scales we requested
def calculate_wavelet_scale(
number_of_frequences: int,
frequency_range_min: float,
frequency_range_max: float,
dt: float,
) -> np.ndarray:
s_spacing: np.ndarray = (1.0 / (number_of_frequences - 1)) * np.log2(
frequency_range_max / frequency_range_min
)
scale: np.ndarray = np.power(2, np.arange(0, number_of_frequences) * s_spacing)
frequency_axis_request: np.ndarray = frequency_range_min * np.flip(scale)
return 1.0 / (frequency_axis_request * dt)
def calculate_cone_of_influence(dt: float, frequency_axis: np.ndarray):
wave_scales = 1.0 / (frequency_axis * dt)
cone_of_influence: np.ndarray = np.ceil(np.sqrt(2) * wave_scales).astype(np.int64)
return cone_of_influence
def get_y_ticks(
reduction_to_ticks: int, frequency_axis: np.ndarray, round: int
) -> tuple[np.ndarray, np.ndarray]:
output_ticks = np.arange(
0,
frequency_axis.shape[0],
int(np.floor(frequency_axis.shape[0] / reduction_to_ticks)),
)
if round < 0:
output_freq = frequency_axis[output_ticks]
else:
output_freq = np.round(frequency_axis[output_ticks], round)
return output_ticks, output_freq
def get_x_ticks(
reduction_to_ticks: int, dt: float, number_of_timesteps: int, round: int
) -> tuple[np.ndarray, np.ndarray]:
time_axis = dt * np.arange(0, number_of_timesteps)
output_ticks = np.arange(
0, time_axis.shape[0], int(np.floor(time_axis.shape[0] / reduction_to_ticks))
)
if round < 0:
output_time_axis = time_axis[output_ticks]
else:
output_time_axis = np.round(time_axis[output_ticks], round)
return output_ticks, output_time_axis
def mask_cone_of_influence(
complex_spectrum: np.ndarray,
cone_of_influence: np.ndarray,
fill_value: float = np.NaN,
) -> np.ndarray:
assert complex_spectrum.shape[0] == cone_of_influence.shape[0]
for frequency_id in range(0, cone_of_influence.shape[0]):
# Front side
start_id: int = 0
end_id: int = int(
np.min((cone_of_influence[frequency_id], complex_spectrum.shape[1]))
)
complex_spectrum[frequency_id, start_id:end_id] = fill_value
start_id = np.max(
(
complex_spectrum.shape[1] - cone_of_influence[frequency_id] - 1,
0,
)
)
end_id = complex_spectrum.shape[1]
complex_spectrum[frequency_id, start_id:end_id] = fill_value
return complex_spectrum
f_test: float = 50 # Hz
number_of_test_samples: int = 1000
# The wavelet we want to use
mother = pywt.ContinuousWavelet("cmor1.5-1.0")
# Parameters for the wavelet transform
number_of_frequences: int = 25 # frequency bands
frequency_range_min: float = 15 # Hz
frequency_range_max: float = 200 # Hz
dt: float = 1.0 / 1000 # sec
t_test: np.ndarray = np.arange(0, number_of_test_samples) * dt
test_data: np.ndarray = np.sin(2 * np.pi * f_test * t_test)
wave_scales = calculate_wavelet_scale(
number_of_frequences=number_of_frequences,
frequency_range_min=frequency_range_min,
frequency_range_max=frequency_range_max,
dt=dt,
)
complex_spectrum, frequency_axis = pywt.cwt(
data=test_data, scales=wave_scales, wavelet=mother, sampling_period=dt
)
cone_of_influence = calculate_cone_of_influence(dt, frequency_axis)
complex_spectrum = mask_cone_of_influence(
complex_spectrum=complex_spectrum,
cone_of_influence=cone_of_influence,
fill_value=np.NaN,
)
plt.imshow(abs(complex_spectrum) ** 2, cmap="hot", aspect="auto")
plt.colorbar()
y_ticks, y_labels = get_y_ticks(
reduction_to_ticks=10, frequency_axis=frequency_axis, round=1
)
x_ticks, x_labels = get_x_ticks(
reduction_to_ticks=10, dt=dt, number_of_timesteps=complex_spectrum.shape[1], round=2
)
plt.yticks(y_ticks, y_labels)
plt.xticks(x_ticks, x_labels)
plt.xlabel("Time [sec]")
plt.ylabel("Frequency [Hz]")
plt.show()
```
![figure 8](image8.png)