Reduce memory footprint of dss_line#58
Conversation
|
Thanks @eort. I have looked at the changes only quickly, but I can already tell you that the CI do not pass any more (you can check that locally by running I think only a subset of those changes are beneficial (the "inplace" recentring for instance). Some others I am less sure about (for instance when you compute I will try to post proper benchmarks (computation time and memory consumption) soon to get to the bottom of this. |
|
Yeah I wasn't sure about those ones either, but for me saving memory was more critical than saving time. In any case, looking forward to those benchmarks! (Hope CI is happier now) |
Bah, I meant run I'll post the benchmarks shortly |
|
Ok, I've run a memory profiler to assess changes. Using simulated data with 100 channels and 1e6 time points @ 200 hz. With your changesMemory: DetailsLine # Mem usage Increment Occurences Line Contents
============================================================
139 3058.7 MiB 3058.7 MiB 1 @profile
140 def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
141 show=False):
142 """Apply DSS to remove power line artifacts.
143
144 Implements the ZapLine algorithm described in [1]_.
145
146 Parameters
147 ----------
148 X : data, shape=(n_samples, n_chans, n_trials)
149 Input data.
150 fline : float
151 Line frequency (normalized to sfreq, if ``sfreq`` == 1).
152 sfreq : float
153 Sampling frequency (default=1, which assymes ``fline`` is normalised).
154 nremove : int
155 Number of line noise components to remove (default=1).
156 nfft : int
157 FFT size (default=1024).
158 nkeep : int
159 Number of components to keep in DSS (default=None).
160 blocksize : int
161 If not None (default), covariance is computed on blocks of
162 ``blocksize`` samples. This may improve performance for large datasets.
163 show: bool
164 If True, show DSS results (default=False).
165
166 Returns
167 -------
168 y : array, shape=(n_samples, n_chans, n_trials)
169 Denoised data.
170 artifact : array, shape=(n_samples, n_chans, n_trials)
171 Artifact
172
173 Examples
174 --------
175 Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
176 results:
177 >>> dss_line(X, 50/1000)
178
179 Removing 4 line-dominated components:
180 >>> dss_line(X, 50/1000, 4)
181
182 Truncating PCs beyond the 30th to avoid overfitting:
183 >>> dss_line(X, 50/1000, 4, nkeep=30);
184
185 Return cleaned data in y, noise in yy, do not plot:
186 >>> [y, artifact] = dss_line(X, 60/1000)
187
188 References
189 ----------
190 .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
191 remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
192
193 """
194 3058.7 MiB 0.0 MiB 1 if X.shape[0] < nfft:
195 print('Reducing nfft to {}'.format(X.shape[0]))
196 nfft = X.shape[0]
197 3058.7 MiB 0.0 MiB 1 n_samples, n_chans, n_trials = theshapeof(X)
198 3058.7 MiB 0.0 MiB 1 if blocksize is None:
199 3058.7 MiB 0.0 MiB 1 blocksize = n_samples
200
201 # Recentre data
202 3058.7 MiB 0.0 MiB 1 X = demean(X, inplace=True)
203
204 # Cancel line_frequency and harmonics + light lowpass
205 5370.5 MiB 2311.7 MiB 1 X_filt = smooth(X, sfreq / fline)
206
207 # X - X_filt results in the artifact plus some residual biological signal
208 # Reduce dimensionality to avoid overfitting
209 5370.5 MiB 0.0 MiB 1 if nkeep is not None:
210 cov_X_res = tscov(X - X_filt)[0]
211 V, _ = pca(cov_X_res, nkeep)
212 X_noise_pca = (X - X_filt) @ V
213 else:
214 7636.4 MiB 2265.9 MiB 1 X_noise_pca = (X - X_filt).copy()
215 7636.4 MiB 0.0 MiB 1 nkeep = n_chans
216
217 # Compute blockwise covariances of raw and biased data
218 7636.4 MiB 0.0 MiB 1 n_harm = np.floor((sfreq / 2) / fline).astype(int)
219 7636.4 MiB 0.0 MiB 1 c0 = np.zeros((nkeep, nkeep))
220 7636.5 MiB 0.1 MiB 1 c1 = np.zeros((nkeep, nkeep))
221 7777.3 MiB 0.0 MiB 4 for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
222 7636.5 MiB 0.0 MiB 2 axis=(0, 1))[::blocksize, 0]:
223 # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
224 7636.5 MiB 0.0 MiB 1 if X_block.ndim == 3:
225 X_block = X_block.transpose(1, 2, 0)
226
227 # bias data
228 7637.0 MiB 0.5 MiB 1 c0 += tscov(X_block)[0]
229 7777.3 MiB 140.3 MiB 1 c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0]
230
231 # DSS to isolate line components from residual
232 7778.1 MiB 0.8 MiB 1 todss, _, pwr0, pwr1 = dss0(c0, c1)
233
234 7778.1 MiB 0.0 MiB 1 if show:
235 import matplotlib.pyplot as plt
236 plt.plot(pwr1 / pwr0, '.-')
237 plt.xlabel('component')
238 plt.ylabel('score')
239 plt.title('DSS to enhance line frequencies')
240 plt.show()
241
242 # Remove line components from X_noise
243 7778.1 MiB 0.0 MiB 1 idx_remove = np.arange(nremove)
244 7778.1 MiB 0.0 MiB 1 X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
245 10056.1 MiB 2278.0 MiB 1 X_res = tsr(X - X_filt, X_artifact)[0] # project them out
246 # reconstruct clean signal
247 12322.0 MiB 2265.9 MiB 1 y = X_filt + X_res
248
249 # Power of components
250 12322.0 MiB 0.0 MiB 1 p = wpwr(X - y)[0] / wpwr(X)[0]
251 12322.1 MiB 0.0 MiB 1 print('Power of components removed by DSS: {:.2f}'.format(p))
252 # return the reconstructed clean signal, and the artifact
253 14588.0 MiB 2265.9 MiB 1 return y, X - y
Computation time: Details 458451 function calls (455790 primitive calls) in 64.340 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 64.340 64.340 /memory_profiler.py:1140(wrapper)
1 0.313 0.313 64.017 64.017 /memory_profiler.py:715(f)
1 7.629 7.629 63.703 63.703 /meegkit/meegkit/dss.py:138(dss_line)
1473/272 1.176 0.001 36.986 0.136 {built-in method numpy.core._multiarray_umath.implement_array_function}
1 4.796 4.796 29.480 29.480 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
2 0.000 0.000 24.615 12.308 /numpy/fft/_pocketfft.py:49(_raw_fft)
2 24.615 12.307 24.615 12.307 {built-in method numpy.fft._pocketfft_internal.execute}
1 0.000 0.000 14.306 14.306 <__array_function__ internals>:2(fft)
1 0.000 0.000 14.306 14.306 /numpy/fft/_pocketfft.py:122(fft)
1 0.000 0.000 10.309 10.309 <__array_function__ internals>:2(ifft)
1 0.000 0.000 10.309 10.309 /numpy/fft/_pocketfft.py:219(ifft)
1 4.082 4.082 10.040 10.040 /meegkit/meegkit/tspca.py:71(tsr)
1 0.000 0.000 7.492 7.492 /meegkit/meegkit/utils/sig.py:114(smooth)
100/1 0.001 0.000 7.492 7.492 <__array_function__ internals>:2(apply_along_axis)
100/1 1.763 0.018 7.490 7.490 /numpy/lib/shape_base.py:267(apply_along_axis)
99 0.008 0.000 6.035 0.061 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
99 0.008 0.000 5.997 0.061 /scipy/signal/signaltools.py:1866(lfilter)
99 0.001 0.000 5.536 0.056 /scipy/signal/signaltools.py:2038(<lambda>)
99 0.001 0.000 5.536 0.056 <__array_function__ internals>:2(convolve)
99 0.003 0.000 5.534 0.056 /numpy/core/numeric.py:753(convolve)
99 5.531 0.056 5.531 0.056 {built-in method numpy.core._multiarray_umath.correlate}
4 3.269 0.817 4.477 1.119 /meegkit/meegkit/utils/denoise.py:10(demean)
4 0.000 0.000 4.386 1.097 /meegkit/meegkit/utils/covariances.py:170(tscov)
17 3.882 0.228 3.882 0.228 {method 'copy' of 'numpy.ndarray' objects}
6 0.000 0.000 2.974 0.496 /meegkit/meegkit/utils/matrix.py:211(multishift)
2 2.283 1.142 2.765 1.382 /meegkit/meegkit/utils/denoise.py:102(wpwr)
64 0.530 0.008 1.899 0.030 /meegkit/meegkit/utils/matrix.py:653(_check_data)
259 1.757 0.007 1.757 0.007 {method 'reduce' of 'numpy.ufunc' objects}
1 0.000 0.000 1.713 1.713 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
50 0.000 0.000 1.500 0.030 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
122 0.002 0.000 1.484 0.012 /numpy/core/fromnumeric.py:70(_wrapreduction)
2 0.000 0.000 1.201 0.600 <__array_function__ internals>:2(einsum)
2 0.000 0.000 1.201 0.600 /numpy/core/einsumfunc.py:997(einsum)
2 1.201 0.600 1.201 0.600 {built-in method numpy.core._multiarray_umath.c_einsum}
10 0.000 0.000 1.166 0.117 <__array_function__ internals>:2(dot)Without your changes (#57 )Memory: DetailsLine # Mem usage Increment Occurences Line Contents
============================================================
138 3050.5 MiB 3050.5 MiB 1 @profile
139 def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
140 show=False):
141 """Apply DSS to remove power line artifacts.
142
143 Implements the ZapLine algorithm described in [1]_.
144
145 Parameters
146 ----------
147 X : data, shape=(n_samples, n_chans, n_trials)
148 Input data.
149 fline : float
150 Line frequency (normalized to sfreq, if ``sfreq`` == 1).
151 sfreq : float
152 Sampling frequency (default=1, which assymes ``fline`` is normalised).
153 nremove : int
154 Number of line noise components to remove (default=1).
155 nfft : int
156 FFT size (default=1024).
157 nkeep : int
158 Number of components to keep in DSS (default=None).
159 blocksize : int
160 If not None (default), covariance is computed on blocks of
161 ``blocksize`` samples. This may improve performance for large datasets.
162 show: bool
163 If True, show DSS results (default=False).
164
165 Returns
166 -------
167 y : array, shape=(n_samples, n_chans, n_trials)
168 Denoised data.
169 artifact : array, shape=(n_samples, n_chans, n_trials)
170 Artifact
171
172 Examples
173 --------
174 Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
175 results:
176 >>> dss_line(X, 50/1000)
177
178 Removing 4 line-dominated components:
179 >>> dss_line(X, 50/1000, 4)
180
181 Truncating PCs beyond the 30th to avoid overfitting:
182 >>> dss_line(X, 50/1000, 4, nkeep=30);
183
184 Return cleaned data in y, noise in yy, do not plot:
185 >>> [y, artifact] = dss_line(X, 60/1000)
186
187 References
188 ----------
189 .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
190 remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
191
192 """
193 3050.5 MiB 0.0 MiB 1 if X.shape[0] < nfft:
194 print('Reducing nfft to {}'.format(X.shape[0]))
195 nfft = X.shape[0]
196 3050.5 MiB 0.0 MiB 1 n_samples, n_chans, n_trials = theshapeof(X)
197 3050.5 MiB 0.0 MiB 1 if blocksize is None:
198 3050.5 MiB 0.0 MiB 1 blocksize = n_samples
199
200 # Recentre data
201 5316.5 MiB 2265.9 MiB 1 X = demean(X)
202
203 # Cancel line_frequency and harmonics + light lowpass
204 7628.2 MiB 2311.7 MiB 1 X_filt = smooth(X, sfreq / fline)
205
206 # Subtract clean data from original data. The result is the artifact plus
207 # some residual biological signal
208 9894.1 MiB 2265.9 MiB 1 X_noise = X - X_filt
209
210 # Reduce dimensionality to avoid overfitting
211 9894.1 MiB 0.0 MiB 1 if nkeep is not None:
212 cov_X_res = tscov(X_noise)[0]
213 V, _ = pca(cov_X_res, nkeep)
214 X_noise_pca = X_noise @ V
215 else:
216 12160.1 MiB 2265.9 MiB 1 X_noise_pca = X_noise.copy()
217 12160.1 MiB 0.0 MiB 1 nkeep = n_chans
218
219 # Compute blockwise covariances of raw and biased data
220 12160.1 MiB 0.0 MiB 1 n_harm = np.floor((sfreq / 2) / fline).astype(int)
221 12160.1 MiB 0.0 MiB 1 c0 = np.zeros((nkeep, nkeep))
222 12160.1 MiB 0.0 MiB 1 c1 = np.zeros((nkeep, nkeep))
223 12160.1 MiB 0.0 MiB 4 for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
224 12160.1 MiB 0.0 MiB 2 axis=(0, 1))[::blocksize, 0]:
225 # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
226 12160.1 MiB 0.0 MiB 1 if X_block.ndim == 3:
227 X_block = X_block.transpose(1, 2, 0)
228
229 # bias data
230 4661.1 MiB -7499.0 MiB 1 X_bias = gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm)
231 6728.9 MiB 2067.8 MiB 1 c0 += tscov(X_block)[0]
232 6729.2 MiB 0.3 MiB 1 c1 += tscov(X_bias)[0]
233
234 # DSS to isolate line components from residual
235 6731.8 MiB -5428.2 MiB 1 todss, _, pwr0, pwr1 = dss0(c0, c1)
236
237 6731.8 MiB 0.0 MiB 1 if show:
238 import matplotlib.pyplot as plt
239 plt.plot(pwr1 / pwr0, '.-')
240 plt.xlabel('component')
241 plt.ylabel('score')
242 plt.title('DSS to enhance line frequencies')
243 plt.show()
244
245 # Remove line components from X_noise
246 6731.8 MiB 0.0 MiB 1 idx_remove = np.arange(nremove)
247 6754.8 MiB 23.0 MiB 1 X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
248 6822.9 MiB 68.1 MiB 1 X_res = tsr(X_noise, X_artifact)[0] # project them out
249
250 # reconstruct clean signal
251 12540.8 MiB 5717.9 MiB 1 y = X_filt + X_res
252 17072.3 MiB 4531.4 MiB 1 artifact = X - y
253
254 # Power of components
255 9515.1 MiB -7557.2 MiB 1 p = wpwr(X - y)[0] / wpwr(X)[0]
256 9515.3 MiB 0.2 MiB 1 print('Power of components removed by DSS: {:.2f}'.format(p))
257 9515.3 MiB 0.0 MiB 1 return y, artifact
Computation time: Details 458499 function calls (455838 primitive calls) in 99.589 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.002 0.002 99.589 99.589 /memory_profiler.py:1140(wrapper)
1 1.202 1.202 99.242 99.242 /memory_profiler.py:715(f)
1 18.173 18.173 98.039 98.039 /meegkit/meegkit/dss.py:138(dss_line)
1478/277 1.039 0.001 47.523 0.172 {built-in method numpy.core._multiarray_umath.implement_array_function}
1 3.782 3.782 37.715 37.715 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
2 0.001 0.000 33.862 16.931 /numpy/fft/_pocketfft.py:49(_raw_fft)
2 33.861 16.931 33.861 16.931 {built-in method numpy.fft._pocketfft_internal.execute}
1 4.969 4.969 20.961 20.961 /meegkit/meegkit/tspca.py:71(tsr)
1 0.000 0.000 18.838 18.838 <__array_function__ internals>:2(ifft)
1 0.000 0.000 18.838 18.838 /numpy/fft/_pocketfft.py:219(ifft)
4 7.452 1.863 15.295 3.824 /meegkit/meegkit/utils/denoise.py:10(demean)
1 0.000 0.000 15.024 15.024 <__array_function__ internals>:2(fft)
1 0.000 0.000 15.024 15.024 /numpy/fft/_pocketfft.py:122(fft)
22 11.456 0.521 11.456 0.521 {method 'copy' of 'numpy.ndarray' objects}
1 0.000 0.000 7.932 7.932 /meegkit/meegkit/utils/sig.py:114(smooth)
100/1 0.001 0.000 7.932 7.932 <__array_function__ internals>:2(apply_along_axis)
100/1 1.763 0.018 7.931 7.931 /numpy/lib/shape_base.py:267(apply_along_axis)
5 0.001 0.000 6.655 1.331 /meegkit/meegkit/utils/matrix.py:497(fold)
99 0.008 0.000 6.465 0.065 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
99 0.008 0.000 6.427 0.065 /scipy/signal/signaltools.py:1866(lfilter)
99 0.001 0.000 5.971 0.060 /scipy/signal/signaltools.py:2038(<lambda>)
99 0.001 0.000 5.971 0.060 <__array_function__ internals>:2(convolve)
99 0.003 0.000 5.969 0.060 /numpy/core/numeric.py:753(convolve)
99 5.966 0.060 5.966 0.060 {built-in method numpy.core._multiarray_umath.correlate}
2 3.924 1.962 5.269 2.634 /meegkit/meegkit/utils/denoise.py:93(wpwr)
4 0.001 0.000 5.217 1.304 /meegkit/meegkit/utils/covariances.py:170(tscov)
6 0.000 0.000 3.875 0.646 /meegkit/meegkit/utils/matrix.py:211(multishift)
64 0.454 0.007 2.657 0.042 /meegkit/meegkit/utils/matrix.py:652(_check_data)
50 0.000 0.000 2.263 0.045 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
259 1.915 0.007 1.915 0.007 {method 'reduce' of 'numpy.ufunc' objects}
1 0.000 0.000 1.721 1.721 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
385 1.698 0.004 1.698 0.004 {built-in method numpy.zeros}
122 0.002 0.000 1.627 0.013 /numpy/core/fromnumeric.py:70(_wrapreduction)
73 0.000 0.000 1.570 0.022 <__array_function__ internals>:2(iscomplex)
73 0.002 0.000 1.569 0.021 /numpy/lib/type_check.py:210(iscomplex)
2 0.000 0.000 1.223 0.612 <__array_function__ internals>:2(einsum)
2 0.000 0.000 1.223 0.612 /numpy/core/einsumfunc.py:997(einsum)
2 1.223 0.612 1.223 0.612 {built-in method numpy.core._multiarray_umath.c_einsum}
15 0.000 0.000 1.075 0.072 /meegkit/meegkit/utils/matrix.py:511(unfold)
10 0.000 0.000 1.027 0.103 <__array_function__ internals>:2(dot) |
|
Okay, that is surprising.
That's it, right? Pytest is passing now locally. The issue was the change to matrix.py |
|
I've made some changes. The only thing I reversed from your code is the multiple X-X_filt, which cannot possibly beneficial in terms of computations. If this is ok with you then let's merge this into #57 |
dss_linedss_line
|
Sure, sounds good. |
* [ENH] make dss_line() faster - add blocksize parameter to dss_line() - matmul3d works with 2d data * Update requirements.txt * require python 3.7+ * Reduce memory footprint of `dss_line` (#58) * memsaving suggestions * remove testing code * fix pep and flake * undo matrix change * undo matrix reshaping * compromise Co-authored-by: nbara <10333715+nbara@users.noreply.github.com> * doc Co-authored-by: eort <eduardxort@gmail.com>
Here the promised PR. I don't have a good estimate of the saved memory (as it never worked with the original code) but the gain is at least 50% (reduced from far more than 24gb to 15gb at the max
In a nutshell:
fftresults, but feed it directly intoifftGenerally, my motivation was quite selfish. I wanted to make the algorithm work for my data. So, while I obviously tried to not break anything, there is a chance that for other data things might not work anymore. Particularly, the edit in
matrix.pyI suggested, because I simply don't understand why the code was what was. The new version works for my purposes, but perhaps not for others.If the PR is of any use to you, I can try to add a few tests.
Overall, the results looks really nice, I think:

ps. I am not the most versatile github user, so sorry that this is a separate PR and not an upgrade to #57 (if this was desirable)