import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 1. Defining the Trigger and Poisoning Logic (Mechanisms)
def apply_trigger(images, labels, trigger_value=1.0, poison_rate=0.05, target_label=0):
    """Simulate a simple backdoor attack by stamping a one-pixel trigger.

    A random subset of samples gets pixel (0, 0) of every channel set to
    ``trigger_value`` and its label overwritten with ``target_label``. The
    input tensors are left untouched; clones are modified and returned.

    Args:
        images: Tensor indexed as ``[sample, channel, row, col]``.
        labels: Tensor of labels, expected to hold one entry per image.
        trigger_value: Value written into the trigger pixel.
        poison_rate: Fraction of samples to poison, in ``[0, 1]``. The number
            of poisoned samples is ``int(len(images) * poison_rate)``
            (truncated toward zero).
        target_label: Label assigned to every poisoned sample.

    Returns:
        Tuple ``(poisoned_images, poisoned_labels, is_poisoned)`` where
        ``is_poisoned`` is a boolean tensor of length ``len(images)`` that is
        True at the poisoned positions.

    Raises:
        ValueError: If ``poison_rate`` is outside ``[0, 1]``.
    """
    # A negative rate would otherwise turn into a negative slice bound below
    # and silently poison almost the whole dataset.
    if not 0.0 <= poison_rate <= 1.0:
        raise ValueError(f"poison_rate must be in [0, 1], got {poison_rate}")

    poisoned_images = images.clone()
    poisoned_labels = labels.clone()

    num_samples = len(images)
    num_poisoned = int(num_samples * poison_rate)
    indices = torch.randperm(num_samples)[:num_poisoned]

    # Simple trigger: set pixel (0, 0) of every channel to trigger_value.
    poisoned_images[indices, :, 0, 0] = trigger_value
    poisoned_labels[indices] = target_label

    is_poisoned = torch.zeros(num_samples, dtype=torch.bool)
    is_poisoned[indices] = True

    return poisoned_images, poisoned_labels, is_poisoned
# 2. Sharpness-Aware Minimization (SAM) Implementation (Key Design)
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    Each update uses two forward/backward passes:

    1. ``first_step`` moves every parameter along its gradient by
       ``rho * grad / ||grad||`` (the perturbation ``e_w``).
    2. ``second_step`` removes that perturbation and lets the base optimizer
       update the original weights using the gradients computed at the
       perturbed point.

    Args:
        params: Parameters (or parameter groups) to optimize.
        base_optimizer: Optimizer class (not an instance), e.g.
            ``torch.optim.SGD``; it is instantiated on this optimizer's
            parameter groups.
        rho: Non-negative size of the perturbation neighborhood.
        **kwargs: Forwarded to ``base_optimizer`` (e.g. ``lr``, ``momentum``)
            and also stored as defaults of this optimizer.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)

        # Share the param-group list with the base optimizer so that changes
        # to a group (e.g. by an LR scheduler) are seen by both.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights along the current gradient direction.

        Gradients must already be populated by a backward pass. The applied
        perturbation is remembered per parameter so ``second_step`` can undo
        it.

        Args:
            zero_grad: If True, clear the gradients after perturbing.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # The epsilon guards against division by zero for all-zero grads.
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation and apply the base optimizer update.

        Args:
            zero_grad: If True, clear the gradients after the update.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # Revert based on the recorded perturbation rather than on
                # p.grad: a parameter perturbed in first_step must return to
                # "w" even if it has no gradient in the second pass.
                p_state = self.state.get(p)
                if not p_state or "e_w" not in p_state:
                    continue
                p.sub_(p_state.pop("e_w"))  # get back to "w"

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update using a closure.

        Args:
            closure: Callable that re-evaluates the model, calls
                ``backward()`` and returns the loss. Gradients for the first
                step must already be populated before calling ``step``.

        Returns:
            The value returned by ``closure`` at the perturbed weights.

        Raises:
            ValueError: If ``closure`` is None; SAM needs a second
                forward/backward pass.
        """
        if closure is None:
            raise ValueError(
                "SAM requires a closure that does a full forward/backward pass"
            )
        # This method runs under no_grad, so the closure must re-enable grad.
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        loss = closure()
        self.second_step()
        return loss

    def _grad_norm(self):
        """Return the global L2 norm over all available parameter gradients.

        Returns a zero scalar tensor when no parameter has a gradient.
        """
        # torch.stack needs every entry on one device, so each per-parameter
        # norm is moved to the device of the first parameter.
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            # torch.stack([]) would raise; with no gradients the norm is 0.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)
# 3. Method: Poisoned Sample Detection (Simplified Activation Clustering)
def detect_poisoned_samples(model, dataloader, device, n_components=10):
    """Score samples by how far their activations lie from the mean.

    Simplified version of PSD: internal activations are projected onto their
    top principal components and each sample is scored by its Euclidean
    distance from the mean in that space (in the spirit of spectral
    signatures). Larger scores indicate stronger outliers. The original notes
    state that, in the paper, SAM sharpens the difference in these
    activations.

    Args:
        model: Module exposing ``eval()`` and ``extract_features(inputs)``.
            It is switched to eval mode and left in that mode.
        dataloader: Iterable yielding ``(inputs, _)`` pairs; iterate it in a
            fixed order (no shuffling) so scores line up with the dataset.
        device: Device the inputs are moved to before feature extraction.
        n_components: Maximum number of principal components to keep. It is
            capped at ``min(n_samples, n_features)``.

    Returns:
        1-D NumPy array with one score per sample, in iteration order. Empty
        when the dataloader yields no batches.

    Raises:
        ValueError: If ``n_components`` is smaller than 1.
    """
    if n_components < 1:
        raise ValueError(f"n_components must be >= 1, got {n_components}")

    model.eval()
    activations = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            # In a real scenario, hook the last conv layer.
            # Here we use the penultimate layer output.
            feat = model.extract_features(inputs)
            # Flatten per sample so higher-rank feature maps are accepted too.
            activations.append(feat.reshape(feat.size(0), -1).cpu())

    if not activations:
        # torch.cat would raise on an empty list; there is nothing to score.
        return np.zeros(0, dtype=np.float32)

    activations = torch.cat(activations, dim=0).numpy()

    # PCA via an exact SVD of the centered activations: rows of
    # U[:, :k] * S[:k] are the samples' coordinates on the top-k principal
    # axes. An exact solver keeps the scores deterministic.
    centered = activations - activations.mean(axis=0, keepdims=True)
    k = min(n_components, *centered.shape)
    u, s, _ = np.linalg.svd(centered, full_matrices=False)
    reduced_feats = u[:, :k] * s[:k]

    # Score based on the distance from the mean in the principal component space.
    mean_feat = np.mean(reduced_feats, axis=0)
    scores = np.linalg.norm(reduced_feats - mean_feat, axis=1)

    return scores
# 4. Simple Model Architecture
class SimpleNet(nn.Module):
    """Small CNN: one 3x3 convolution followed by two linear layers.

    Args:
        in_channels: Number of input channels.
        image_size: Height/width of the (square) input images.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed architecture for 1x28x28
    inputs with 10 classes.
    """

    def __init__(self, in_channels=1, image_size=28, num_classes=10):
        super().__init__()
        # A 3x3 convolution without padding shrinks each spatial side by 2.
        conv_out = image_size - 2
        self.conv1 = nn.Conv2d(in_channels, 16, 3)
        self.fc1 = nn.Linear(16 * conv_out * conv_out, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Reuse extract_features so both paths always share the same trunk.
        return self.fc2(F.relu(self.extract_features(x)))

    def extract_features(self, x):
        """Return the 64-dimensional fc1 output, before the last ReLU."""
        x = F.relu(self.conv1(x))
        x = x.flatten(start_dim=1)
        return self.fc1(x)  # Features before last ReLU
# 5. Main Logic Simulation
def main():
    """Run the end-to-end demo: poison data, train with SAM, score samples.

    Builds a random dummy dataset, poisons a fraction of it, trains a
    SimpleNet for one pass over the data with the SAM optimizer, and prints
    the average outlier score of poisoned versus clean samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dummy Dataset
    X = torch.randn(1000, 1, 28, 28)
    y = torch.randint(0, 10, (1000,))

    # Poison the data
    X_poison, y_poison, mask = apply_trigger(X, y)
    dataset = torch.utils.data.TensorDataset(X_poison, y_poison)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = SimpleNet().to(device)

    # Use SAM to amplify backdoor effect
    base_opt = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_opt, rho=0.05, lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    # Training Loop (Stage-1)
    model.train()
    for data, target in loader:
        data, target = data.to(device), target.to(device)

        # SAM First Step: gradient at the current weights, then perturb.
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # SAM Second Step: gradient at the perturbed weights, then update.
        criterion(model(data), target).backward()
        optimizer.second_step(zero_grad=True)

    # Detection (Stage-2 & 3)
    # No shuffling here, so scores stay aligned with the poison mask.
    eval_loader = DataLoader(dataset, batch_size=32, shuffle=False)
    poison_scores = detect_poisoned_samples(model, eval_loader, device)

    # The scores are a NumPy array, so index them with a NumPy boolean mask.
    mask_np = mask.numpy()
    print(f"Average score for poisoned: {poison_scores[mask_np].mean():.4f}")
    print(f"Average score for clean: {poison_scores[~mask_np].mean():.4f}")


if __name__ == "__main__":
    main()