Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nuniz committed Jun 3, 2023
1 parent 9955cb9 commit c16690e
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 1 deletion.
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,45 @@
# Parallel-Implementations-of-Adaptive-Filters
# ParaFilt
parafilt is a Python package that provides a collection of parallel adaptive filter implementations for efficient signal processing applications. It leverages the power of parallel processing using PyTorch, enabling faster and scalable computations on multi-core CPUs and GPUs.

## Features
- Parallel algorithm framework that allows computing iterative algorithms in a parallel way.
- Parallel implementation of popular adaptive filter algorithms, including LMS, NLMS, RLS, and more.
- Possibility for researchers to integrate their own adaptive filter algorithms for parallel computing.
- Comprehensive documentation and examples for quick start and usage guidance.

## Installation
To install Parafilt, you can use `pip`:
```
pip install parafilt
```

## Usage
Here's an example of how to use the package to create and apply the LMS filter:

```python
import parafilt

# Create an instance of the LMS filter
lms_filter = parafilt.LMS(hop=1024, framelen=4096, filterlen=1024).cuda()

# Perform parallel filter iteration
d_est, e = lms_filter(desired_signal, input_signal)
```


## Parallel Algorithm Framework
Parafilt provides a parallel algorithm framework that enables researchers to implement and execute iterative algorithms in a parallelized manner. This framework allows for efficient utilization of multi-core CPUs and GPUs, resulting in significant speedup for computationally intensive algorithms.

To leverage the parallel algorithm framework, researchers can extend the base classes provided by Parafilt and utilize the parallel computation capabilities provided by PyTorch.

## Documentation
For detailed usage instructions, examples, and API documentation, please refer to the documentation.

## Contributing
Contributions are welcome! If you find any issues or have suggestions for improvement, please open an issue or submit a pull request on the GitHub repository.

## License
This project is licensed under the MIT License. See the LICENSE file for more information.

## Contact
For any inquiries or questions, please contact [email protected].
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from parafilter.filters import LMS, RLS
import torch
import numpy as np
import soundfile as sf

fs = 20000
samples = 100000

if __name__ == '__main__':
filt = RLS(hop=1000, framelen=4000, filterlen=1024).cuda()
d = torch.sin(torch.arange(samples) / fs * 2 * 3.1415 * 1000) + torch.sin(
torch.arange(samples) / fs * 2 * 3.1415 * 2000)
x = torch.sin(torch.arange(samples) / fs * 2 * 3.1415 * 2000)
d_est, e = filt(d.unsqueeze(0).cuda(), x.cuda())
y = e[0].cpu().detach().numpy()
sf.write(r'c:\temp\filtered.wav', y / np.max(np.abs(y)), 20000)

pass
3 changes: 3 additions & 0 deletions parafilter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .filters import LMS, RLS
from .base import BaseFilter
from .version import __version__
101 changes: 101 additions & 0 deletions parafilter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
from typing import Optional


class BaseFilter(torch.nn.Module):
def __init__(self, hop: int, framelen: int, filterlen: int, weights_delay: Optional[int] = None,
weights_range: (float, float) = (-65535, 65535)):
'''
Base class for a filter module.
:param hop: Hop size for frame processing.
:param framelen: Length of each frame.
:param filterlen: Length of the filter.
:param weights_delay: Delay for the weights, If None, it is set to framelen/2 (default: None).
:param weights_range: Range for the filter weights (default: (-65535, 65535)).
'''
super(BaseFilter, self).__init__()

# Validate and set filter length
assert filterlen > 0, f'filter_length must be bigger than zero, but obtained {filterlen}'
self.filterlen = filterlen
self.register_buffer('w', torch.zeros(1, 1, filterlen))

# Validate and set hop size
assert hop > 0, 'hop must be larger than zero'
self.hop = hop

# Validate and set frame length
assert framelen >= hop, 'framelen must be larger than hop'
assert framelen > filterlen, 'framelen must be larger than filterlen'
self.framelen = framelen

# Validate and set weights delay
assert weights_delay is None or 0 <= weights_delay < filterlen, \
f'delay must be between 0 to {filterlen} (filterlen), but obtained {weights_delay}'
self.weights_delay = filterlen // 2 if weights_delay is None else weights_delay
self.weights_range = weights_range

def reset(self):
'''
Reset the filter weights.
'''
self.w *= 0

def filt(self, x: torch.Tensor) -> torch.Tensor:
'''
Apply the filter to the input tensor.
:param x: Input tensor.
:return:
torch.Tensor: Filtered output tensor.
'''
return torch.einsum('ijk, pjk->ij', self.w, x)

def iterate(self, d: torch.Tensor, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
'''
Iterate over the filter weights and inputs.
:param d: Desired signal tensor.
:param x: Input tensor.
:return:
torch.Tensor: Estimated output tensor.
torch.Tensor: Error tensor.
'''
# Placeholder for filter iteration implementation
raise NotImplementedError

def forward(self, d: torch.Tensor, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
'''
Apply the filter to the input signals.
:param d: Desired signal tensor.
:param x: Input tensor.
:return:
torch.Tensor: Estimated output tensor.
torch.Tensor: Error tensor.
'''
self.reset() # Reset weights

# Input validation
assert d.ndim == 2, f'd ndim must be 2, but obtained {d.ndim}'
assert x.ndim == 1, f'x ndim must be 1, but obtained {x.ndim}'

# Unfold input tensors
d = d.unfold(dimension=-1, size=self.framelen, step=self.hop)
x = x.unfold(dimension=-1, size=self.framelen, step=self.hop).unsqueeze(0)

# Initialize intermediate tensors
d_est = torch.zeros_like(d)[..., self.weights_delay: -(self.filterlen - self.weights_delay)]
e = torch.zeros_like(d)[..., self.weights_delay: -(self.filterlen - self.weights_delay)]

# Repeat weight buffer along batch dimension
self.w = self.w.repeat(d.shape[0], d.shape[1], 1)

# Iterate over samples
for i in range(self.weights_delay, e.shape[-1]):
d_est[..., i], e[..., i] = \
self.iterate(d[..., i], x[..., i - self.weights_delay: i - self.weights_delay + self.filterlen])
self.w = torch.clip(self.w, min=self.weights_range[0], max=self.weights_range[1])

# Concatenate and reshape intermediate tensors
d_est = torch.cat((d_est[:, 0, :], d_est[:, 1:, self.framelen - self.hop:].reshape(d_est.shape[0], -1)), -1)
e = torch.cat((e[:, 0, :], e[:, 1:, self.framelen - self.hop:].reshape(e.shape[0], -1)), -1)

return d_est, e
95 changes: 95 additions & 0 deletions parafilter/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from .base import BaseFilter
import torch
from typing import Optional


class TemplateFilter(BaseFilter):
def __init__(self, hop: int, framelen: int, filterlen: int = 1024, weights_delay: Optional[int] = None,
weights_range: (float, float) = (-65535, 65535)):
'''
Template filter class that extends the BaseFilter class.
:param hop: Hop size for frame processing.
:param framelen: Length of each frame.
:param filterlen: Length of the filter.
:param weights_delay: Delay for the weights, If None, it is set to framelen/2 (default: None).
:param weights_range: Range for the filter weights (default: (-65535, 65535)).
'''
super().__init__(hop=hop, framelen=framelen, filterlen=filterlen, weights_delay=weights_delay,
weights_range=weights_range)

def iterate(self, d: torch.Tensor, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
'''
Placeholder method for the filter iteration.
:param d: Desired signal tensor.
Shape: (batch_size, frame_length)
:param x: Input tensor.
Shape: (batch_size, frame_length, filter_length)
:return:
torch.Tensor: Estimated output tensor.
Shape: (batch_size, frame_length)
torch.Tensor: Error tensor.
Shape: (batch_size, frame_length)
'''
raise NotImplementedError


class LMS(BaseFilter):
def __init__(self, hop: int, framelen: int, filterlen: int = 1024, weights_delay: Optional[int] = None,
weights_range: (float, float) = (-65535, 65535), learning_rate: float = 0.01, normalized: bool = True):
'''
LMS filter class that extends the BaseFilter class.
:param hop: Hop size for frame processing.
:param framelen: Length of each frame.
:param filterlen: Length of the filter (default: 1024).
:param weights_delay: Delay for the weights, If None, it is set to framelen/2 (default: None).
:param learning_rate: Learning rate for the LMS algorithm (default: 0.01).
:param normalized: Flag indicating whether to normalize the input energy (default: True).
'''
super().__init__(hop=hop, framelen=framelen, filterlen=filterlen, weights_delay=weights_delay,
weights_range=weights_range)
self.learning_rate = learning_rate
self.normalized = normalized

def iterate(self, d: torch.Tensor, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
'''
Performs one iteration of the LMS algorithm.
:param d: Desired signal tensor.
:param x: Input tensor.
:return:
(torch.Tensor, torch.Tensor): Tuple containing the estimated output and the error signal.
'''
d_est = self.filt(x) # Estimated output using the current filter weights
e = d - d_est # Compute the error signal

# Update the filter weights using the LMS update rule
x_energy = 1 + self.normalized * (torch.einsum('ijk, ijk->j', x, x) - 1)
self.w += self.learning_rate * e.unsqueeze(-1) * x / (
torch.finfo(torch.float64).eps + x_energy.unsqueeze(0).unsqueeze(-1))
return d_est, e


class RLS(BaseFilter):
def __init__(self, hop: int, framelen: int, filterlen: int = 1024, weights_delay: Optional[int] = None,
weights_range: (float, float) = (-65535, 65535), forgetting_factor: float = 1,
inverse_cc_init: float = 1.001):
super().__init__(hop=hop, framelen=framelen, filterlen=filterlen, weights_delay=weights_delay,
weights_range=weights_range)
self.forgetting_factor = forgetting_factor
self.inverse_cc_init = inverse_cc_init
self.register_buffer('inverse_cc', inverse_cc_init * torch.eye(filterlen).unsqueeze(0))

def reset(self):
self.w *= 0
self.inverse_cc = self.inverse_cc_init * torch.eye(self.filterlen).unsqueeze(0).to(self.inverse_cc.device)

def iterate(self, d: torch.Tensor, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
d_est = self.filt(x)
e = d - d_est
g = torch.einsum('ijk, lmk')
# torch.matmul(x, self.inverse_cc) / (
# self.forgetting_factor + torch.matmul(torch.matmul(x, self.inverse_cc), x.permute(0, 2, 1)))
self.w += g * e

# Update inverse correlation matrix
self.inverse_cc = 1 / self.forgetting_factor * self.inverse_cc - e * g
return d_est, e
2 changes: 2 additions & 0 deletions parafilter/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# parafilter version
__version__ = '0.1.0-beta'
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torch
28 changes: 28 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

setuptools.setup(
name="parAdaptive",
version="0.1.1-beta",
author="Asaf Zorea",
author_email="[email protected]",
description="ParallelAdaptiveFilters is a Python package that provides a collection of parallel adaptive filter "
"implementations for efficient signal processing applications.",
long_description=long_description,
long_description_content_type="text/markdown",
license='MIT',
url="https://github.com/nuniz/ParallelAdaptiveFilters",
packages=setuptools.find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*", "tests.*"]),
include_package_data=True,
classifiers=[
"Programming Language :: Python :: 3",
'License :: OSI Approved :: MIT License',
"Operating System :: OS Independent",
],
python_requires=">=3.6",
install_requires=[
"torch",
],
)

0 comments on commit c16690e

Please sign in to comment.