1212"""Fit models that are used for curve fitting.""" 
1313
1414from  abc  import  ABC , abstractmethod 
15- from  typing  import  Callable , List , Optional 
15+ from  typing  import  Callable , List , Optional ,  Union 
1616
1717import  numpy  as  np 
1818
1919from  qiskit_experiments .exceptions  import  AnalysisError 
2020
2121
2222class  FitModel (ABC ):
23-     """Base class of fit models. 
23+     r """Base class of fit models.
2424
2525    This is a function-like object that implements a fit model as a ``__call__`` magic method, 
2626    thus it behaves like a Python function that the SciPy curve_fit solver accepts. 
@@ -29,22 +29,74 @@ class FitModel(ABC):
2929    This class ties together the fit function and associated parameter names to 
3030    perform correct parameter mapping among multiple objective functions with different signatures, 
3131    in which some parameters may be excluded from the fitting when they are fixed. 
32+ 
33+     Examples: 
34+ 
35+         Given we have two functions :math:`F_1(x_1, p_0, p_1, p_2)` and :math:`F_2(x_2, p_0, p_3)`. 
36+         During the fit, we assign :math:`p_1=2` and exclude it from the fitting. 
37+         This is formulated with set operation as follows: 
38+ 
39+         .. math:: 
40+ 
41+             \Theta_1 = \{ p_0, p_1, p_2 \}, \Theta_2 = \{p_0, p_3\}, \Theta_{\rm fix} = \{p_1\} 
42+ 
43+         Note that :class:`FitModel` subclass is instantiated with a list of 
44+         :math:`F_1` and :math:`F_2` (``fit_functions``) together with 
45+         a list of :math:`\Theta_1` and :math:`\Theta_2` (``signatures``) and 
46+         :math:`\Theta_{\rm fix}` (``fixed_parameters``). 
47+         The signature of new fit model instance will be 
48+         :math:`\Theta = (\Theta_1 \cup \Theta_2) - \Theta_{\rm fix} = \{ p_0, p_2, p_3\}`. 
49+         The fit function that this model provides is accordingly 
50+ 
51+         .. math:: 
52+ 
53+             F(x, \Theta) = F(x_0 \oplus x_1, p_0, p_2, p_3). 
54+ 
55+         This function might be called from the scipy curve fit algorithm 
56+         which only takes variadic arguments (i.e. agnostic to parameter names). 
57+ 
58+         .. math:: 
59+ 
60+             F(x, {\rm *args}) = F(x,\bar{p}_0, \bar{p}_1, \bar{p}_2) 
61+ 
62+         The fit model internally maps :math:`\bar{p}_0 \rightarrow p_0`, 
63+         :math:`\bar{p}_1 \rightarrow p_2`, and :math:`\bar{p}_2 \rightarrow p_3` 
64+         while assigning :math:`p_1=2` when its called from the curve fitting algorithm. 
65+         Note that this mapping is performed in the ``__call__`` method. 
66+         The function signature :math:`\Theta` is provided with the property :attr:`signature`. 
67+ 
68+     Notes: 
69+ 
70+         This class is usually instantiated with the :class:`SeriesDef` in the 
71+         ``__init_subclass__`` method of :class:`CurveAnalysis` subclasses. 
72+         User doesn't need to take care of input values to the constructor 
73+         unless one manually instantiates the class for debugging purpose. 
3274    """ 
3375
3476    def  __init__ (
3577        self ,
3678        fit_functions : List [Callable ],
3779        signatures : List [List [str ]],
38-         fit_models : Optional [List [str ]] =  None ,
80+         fit_models : Optional [Union [ List [str ],  str ]] =  None ,
3981        fixed_parameters : Optional [List [str ]] =  None ,
4082    ):
4183        """Create new fit model. 
4284
4385        Args: 
44-             fit_functions: List of callables that defines fit function of a single series. 
45-             signatures: List of parameter names of a single series. 
46-             fit_models: List of string representation of fit functions. 
47-             fixed_parameters: List of parameter names that are fixed in the fit. 
86+             fit_functions: List of callables that forms the fit model for a 
87+                 particular curve analysis class. It may consists of multiple curves. 
88+             signatures: List of argument names that each fit function callable takes. 
89+                 The length of the list should be identical to the ``fit_functions``. 
90+             fit_models: String representation of fit functions. 
91+                 Because this is just a metadata, the format of input value doesn't matter. 
92+                 It may be a single string description for the entire fit model, or 
93+                 list of descriptions for each fit functions. If not provided, 
94+                 "not defined" is stored in the experiment result metadata. 
95+             fixed_parameters: List of parameter names that are not considered to be fit parameter. 
96+                 The value of parameter is provided by analysis default setting or users, 
97+                 which is fixed during the curve fitting. Arbitrary number of parameters 
98+                 in the fit model can be fixed, however, every parameter should be 
99+                 defined in the model. 
48100
49101        Raises: 
50102            AnalysisError: When ``fit_functions`` and ``signatures`` don't match. 
@@ -54,8 +106,14 @@ def __init__(
54106
55107        self ._fit_functions  =  fit_functions 
56108        self ._signatures  =  signatures 
57-         self ._fit_models  =  fit_models  or  [None  for  _  in  range (len (fit_functions ))]
58109
110+         # String representation of the fit model. This is stored as a list of string. 
111+         if  not  fit_models  or  isinstance (fit_models , str ):
112+             fit_models  =  [fit_models ]
113+         self ._fit_models  =  fit_models 
114+ 
115+         # No validation is performed since this class is always instantiated from the 
116+         # curve analysis class itself. The validation is performed there. 
59117        if  not  fixed_parameters :
60118            fixed_parameters  =  []
61119        self ._fixed_params  =  {p : None  for  p  in  fixed_parameters }
@@ -121,13 +179,17 @@ class SingleFitFunction(FitModel):
121179    the fit parameters and the fixed parameters :math:`\Theta_{\rm fix}`. 
122180    The function :math:`f` is usually set by :attr:`SeriesDef.fit_func` which is 
123181    a standard python function. 
182+ 
183+     .. seealso:: 
184+ 
185+         Class :class:`FitModel`. 
124186    """ 
125187
126188    def  __call__ (self , x : np .ndarray , * params ) ->  np .ndarray :
127189        """Compute values of fit functions. 
128190
129191        Args: 
130-             x: Composite  X values array. 
192+             x: Input  X values array. 
131193            *params: Variadic argument provided from the fitter. 
132194
133195        Returns: 
@@ -166,6 +228,10 @@ class CompositeFitFunction(FitModel):
166228    This data represents the location where the function with index ``i`` 
167229    is returned and where the x values :math:`x_i` comes from. 
168230    One must set this data indices before calling the composite fit function. 
231+ 
232+     .. seealso:: 
233+ 
234+         Class :class:`FitModel`. 
169235    """ 
170236
171237    def  __init__ (
@@ -182,7 +248,7 @@ def __call__(self, x: np.ndarray, *params) -> np.ndarray:
182248        """Compute values of fit functions. 
183249
184250        Args: 
185-             x: Composite  X values array. 
251+             x: Input  X values array. 
186252            *params: Variadic argument provided from the fitter. 
187253
188254        Returns: 
0 commit comments