Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,6 @@ to 0 will not result in chunking. This is the default behavior.
.. code-block:: python

from aihwkit_lightning.nn.export import export_to_aihwkit
aihwkit_model = export_to_aihwkit(model=analog_model, max_output_size=-1)
aihwkit_model = export_to_aihwkit(model=analog_model, max_output_size=-1)

A complete example is provided `here <https://github.com/IBM/aihwkit-lightning/tree/main/examples/aihwkit_evaluation>`_.
4 changes: 4 additions & 0 deletions examples/aihwkit_evaluation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# AIHWKIT Evaluation
In this example, a model is converted to an analog-equivalent representation using aihwkit-lightning. It is then exported to an aihwkit representation, which supports evaluation using a statistical inference model. Programming noise is applied using `drift_analog_weights()`.

Note: To run this example, both aihwkit and aihwkit-lightning must be installed.
37 changes: 37 additions & 0 deletions examples/aihwkit_evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

# (C) Copyright 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Example on how to evaluate an AIHWKIT-Lightning model with AIHWKIT."""

import torch
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from model import resnet32
from aihwkit_lightning.simulator.configs import TorchInferenceRPUConfig
from aihwkit_lightning.nn.conversion import convert_to_analog
from aihwkit_lightning.nn.export import export_to_aihwkit

if __name__ == "__main__":
    # Build the digital ResNet-32 and convert it to its analog-equivalent
    # aihwkit-lightning representation.
    model = resnet32()
    rpu_config = TorchInferenceRPUConfig()
    model = convert_to_analog(model, rpu_config)

    # Export the aihwkit-lightning model to an aihwkit model and cast it to
    # single precision.
    aihwkit_model = export_to_aihwkit(model=model, max_output_size=-1)
    aihwkit_model.to(torch.float32)

    # Use the RPU config of the first analog tile as the template that is
    # modified and then applied to the whole model. Fail with a clear message
    # instead of a NameError when the exported model has no analog tiles.
    first_tile = next(iter(aihwkit_model.analog_tiles()), None)
    if first_tile is None:
        raise RuntimeError("The exported model does not contain any analog tiles.")
    new_rpu_config = first_tile.rpu_config

    # Attach a PCM-like noise model and global drift compensation, then push
    # the updated config to the model.
    new_rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
    new_rpu_config.drift_compensation = GlobalDriftCompensation()
    aihwkit_model.replace_rpu_config(new_rpu_config)

    # Switch to evaluation mode before applying the noise. According to the
    # accompanying README, programming noise is applied by
    # drift_analog_weights(); 0.0 is the value passed for the drift time.
    aihwkit_model.eval()
    aihwkit_model.drift_analog_weights(0.0)
120 changes: 120 additions & 0 deletions examples/aihwkit_evaluation/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-

# (C) Copyright 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Resnet32 model definition."""

# pylint: skip-file

import torch
import torch.nn.functional as F
from torch.nn import init

# Preferred torch device: CUDA when available, otherwise the CPU. It is defined
# at import time; none of the definitions visible in this module reference it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _weights_init(m):
    """Re-initialise the weight of a linear or 2D convolution layer in place.

    The weight is filled using ``init.kaiming_normal_``. Modules of any other
    type are left untouched.

    Args:
        m: Module to (possibly) initialise.
    """
    initialised_types = (torch.nn.Linear, torch.nn.Conv2d)
    if not isinstance(m, initialised_types):
        return
    init.kaiming_normal_(m.weight)


class LambdaLayer(torch.nn.Module):
    """Module wrapper that applies an arbitrary callable in ``forward``."""

    def __init__(self, lambd):
        """Store the callable to apply.

        Args:
            lambd: Callable invoked with the forward input.
        """
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the result of calling the stored callable on ``x``."""
        fn = self.lambd
        return fn(x)


class BasicBlock(torch.nn.Module):
    """Residual block with two 3x3 convolutions, each followed by batch norm.

    When the block changes the spatial resolution (``stride != 1``) or the
    number of channels (``in_planes != planes``), the shortcut branch is
    adapted according to ``option``:

    * ``"A"``: parameter-free shortcut that subsamples the input spatially
      and zero-pads the channel dimension.
    * ``"B"``: 1x1 strided convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the first convolution (and of the shortcut).
            option: Shortcut type, ``"A"`` or ``"B"``; only consulted when the
                shortcut cannot be the identity.

        Raises:
            ValueError: If a non-identity shortcut is needed and ``option`` is
                neither ``"A"`` nor ``"B"``, or if option ``"A"`` is requested
                with ``planes < in_planes`` (zero-padding cannot remove
                channels).
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(planes)

        # Identity shortcut unless the output shape differs from the input.
        self.shortcut = torch.nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                # For CIFAR10 ResNet paper uses option A.
                extra_channels = planes - in_planes
                if extra_channels < 0:
                    raise ValueError(
                        "Option 'A' shortcut requires planes >= in_planes, "
                        f"got in_planes={in_planes}, planes={planes}."
                    )
                # Split the missing channels between both ends of the channel
                # axis. For the usual doubling (planes == 2 * in_planes) this
                # equals planes // 4 on each side.
                pad_front = extra_channels // 2
                pad_back = extra_channels - pad_front
                # Subsample with the same stride as conv1 so the spatial size
                # of the shortcut matches the main branch.
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::stride, ::stride],
                        (0, 0, 0, 0, pad_front, pad_back),
                        "constant",
                        0,
                    )
                )
            elif option == "B":
                self.shortcut = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(self.expansion * planes),
                )
            else:
                raise ValueError(
                    f"Unknown shortcut option {option!r}; expected 'A' or 'B'."
                )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(torch.nn.Module):
    """ResNet made of a 3x3 stem convolution, three stages and a linear head.

    The stem maps 3 input channels to 16; the three stages produce 16, 32 and
    64 channels, the last two with stride 2 in their first block.
    """

    def __init__(self, block, num_blocks, n_classes=10):
        """Assemble the network.

        Args:
            block: Block class; called as ``block(in_planes, planes, stride)``
                and expected to expose an ``expansion`` attribute.
            num_blocks: Sequence whose first three entries give the number of
                blocks in each stage.
            n_classes: Number of outputs of the final linear layer.
        """
        super().__init__()
        # Channel count fed into the next block; updated by _make_layer.
        self.in_planes = 16
        self.conv1 = torch.nn.Conv2d(
            3, 16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = torch.nn.Linear(64, n_classes)
        # Re-initialise the weights of all submodules handled by _weights_init.
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage as a ``Sequential`` of blocks.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        ``self.in_planes`` is advanced after every block.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return torch.nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the three stages, average pooling and the linear head."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        # The pooling kernel size is the extent of the last feature-map dimension.
        pooled = F.avg_pool2d(features, features.size(3))
        flattened = pooled.view(pooled.size(0), -1)
        return self.linear(flattened)


def resnet32(n_classes=10):
    """Return a ``ResNet`` of ``BasicBlock``s with five blocks in each of its three stages.

    Args:
        n_classes: Number of outputs of the final linear layer.
    """
    blocks_per_stage = [5] * 3
    return ResNet(BasicBlock, blocks_per_stage, n_classes=n_classes)