@@ -27,7 +27,7 @@ class RocketClassifier(BaseClassifier):
2727
2828 Parameters
2929 ----------
30- num_kernels : int, default=10,000
30+ n_kernels : int, default=10,000
3131 The number of kernels for the Rocket transform.
3232 estimator : sklearn compatible classifier or None, default=None
3333 The estimator used. If None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
@@ -78,27 +78,28 @@ class RocketClassifier(BaseClassifier):
7878 >>> from aeon.datasets import load_unit_test
7979 >>> X_train, y_train = load_unit_test(split="train")
8080 >>> X_test, y_test = load_unit_test(split="test")
81- >>> clf = RocketClassifier(num_kernels =500)
81+ >>> clf = RocketClassifier(n_kernels =500)
8282 >>> clf.fit(X_train, y_train)
8383 RocketClassifier(...)
8484 >>> y_pred = clf.predict(X_test)
8585 """
8686
8787 _tags = {
88- "capability:multithreading" : True ,
8988 "capability:multivariate" : True ,
89+ "capability:multithreading" : True ,
9090 "algorithm_type" : "convolution" ,
91+ "X_inner_type" : "numpy3D" ,
9192 }
9293
9394 def __init__ (
9495 self ,
95- num_kernels = 10000 ,
96+ n_kernels = 10000 ,
9697 estimator = None ,
9798 class_weight = None ,
9899 n_jobs = 1 ,
99100 random_state = None ,
100101 ):
101- self .num_kernels = num_kernels
102+ self .n_kernels = n_kernels
102103 self .estimator = estimator
103104
104105 self .class_weight = class_weight
@@ -112,8 +113,8 @@ def _fit(self, X, y):
112113
113114 Parameters
114115 ----------
115- X : 3D np.ndarray
116- The training data of shape = (n_cases, n_channels, n_timepoints) .
116+ X : 3D np.ndarray or list
117+ Collection of time series .
117118 y : 3D np.ndarray
118119 The class labels, shape = (n_cases,).
119120
@@ -127,10 +128,8 @@ def _fit(self, X, y):
127128 Changes state by creating a fitted model that updates attributes
128129 ending in "_" and sets is_fitted flag to True.
129130 """
130- self .n_cases_ , self .n_channels_ , self .n_timepoints_ = X .shape
131-
132131 self ._transformer = Rocket (
133- num_kernels = self .num_kernels ,
132+ n_kernels = self .n_kernels ,
134133 n_jobs = self .n_jobs ,
135134 random_state = self .random_state ,
136135 )
@@ -160,8 +159,8 @@ def _predict(self, X) -> np.ndarray:
160159
161160 Parameters
162161 ----------
163- X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints)
164- The data to make predictions for .
162+ X : 3D np.ndarray or list
163+ Collection of time series .
165164
166165 Returns
167166 -------
@@ -175,8 +174,8 @@ def _predict_proba(self, X) -> np.ndarray:
175174
176175 Parameters
177176 ----------
178- X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints)
179- The data to make predict probabilities for .
177+ X : 3D np.ndarray or list
178+ Collection of time series .
180179
181180 Returns
182181 -------
@@ -187,9 +186,9 @@ def _predict_proba(self, X) -> np.ndarray:
187186 if callable (m ):
188187 return self .pipeline_ .predict_proba (X )
189188 else :
190- dists = np .zeros ((X . shape [ 0 ] , self .n_classes_ ))
189+ dists = np .zeros ((len ( X ) , self .n_classes_ ))
191190 preds = self .pipeline_ .predict (X )
192- for i in range (0 , X . shape [ 0 ] ):
191+ for i in range (0 , len ( X ) ):
193192 dists [i , np .where (self .classes_ == preds [i ])] = 1
194193 return dists
195194
@@ -215,6 +214,6 @@ def _get_test_params(cls, parameter_set="default"):
215214 `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
216215 """
217216 if parameter_set == "results_comparison" :
218- return {"num_kernels " : 100 }
217+ return {"n_kernels " : 100 }
219218 else :
220- return {"num_kernels " : 20 }
219+ return {"n_kernels " : 20 }
0 commit comments