import pandas as pd
import pymc as pm
import matplotlib.pyplot as plt

from pymc_marketing.mmm import YearlyFourier

# Yearly seasonality component, configured with Fourier order 3.
yearly = YearlyFourier(n_order=3)

# 52 weekly timestamps on the Monday-anchored weekly frequency ("W-MON"),
# generated from the 2023-01-01 start point.
dates = pd.date_range(start="2023-01-01", periods=52, freq="W-MON")

# Day-of-year for each timestamp, converted to a plain NumPy array; this is
# the input handed to the Fourier component below.
dayofyear = dates.dayofyear.to_numpy()

# The seasonal term is created while a PyMC model context is active.
# NOTE(review): `model` and `fourier_trend` are not used further down in this
# script; they are kept as module-level names for interactive inspection.
with pm.Model() as model:
    fourier_trend = yearly.apply(dayofyear)

# Independently of the model above: sample the component's prior, turn the
# samples into a curve, and plot it.
prior = yearly.sample_prior()
curve = yearly.sample_curve(prior)
yearly.plot_curve(curve)
plt.show()