CNN 신경망 커널 시각화

colab에서 진행함



필요한 라이브러리들

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt



# Convert images to tensors in [0, 1], then shift/scale them to [-1, 1]
# using a single-channel mean and std of 0.5.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,)),
])

# Download (if not already present) and load the MNIST train and test splits
trainset = datasets.MNIST('MNIST_data', download=True, train=True, transform=transform)
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)

# Both loaders shuffle, so the order of test batches differs between runs
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)


간단한 CNN 신경망 만들기

class BasicBlock(nn.Module):
  """A single 2-D convolution followed by an in-place ReLU.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of filters (output channels) of the convolution.
    ksize: convolution kernel size.
    stride: convolution stride.
    pad: zero padding added to each spatial border.
  """

  def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=1):
    super(BasicBlock, self).__init__()
    # Kept as a Sequential named `body` so parameter names are
    # `body.0.weight` / `body.0.bias`.
    self.body = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
      nn.ReLU(inplace=True),
    )

  def forward(self, x):
    """Apply the convolution and ReLU to `x` and return the result."""
    out = self.body(x)
    return out

# Define the network architecture
class CNN(nn.Module):
  """Small convolutional classifier for 28x28 images with 10 output scores.

  Layout: two stages of (BasicBlock, BasicBlock, 2x2 max-pool, dropout),
  then a two-layer fully connected head. `forward` returns the raw output
  of the last linear layer; no softmax is applied here.
  """

  def __init__(self):
    super().__init__()
    # Stage 1: 3 -> 32 -> 32 channels
    self.b1 = BasicBlock(3, 32)
    self.b2 = BasicBlock(32, 32)
    # Pooling and dropout modules are shared by both stages and the head
    self.maxpool = nn.MaxPool2d(2, 2)
    self.dropout = nn.Dropout(0.25)
    # Stage 2: 32 -> 64 -> 64 channels
    self.b3 = BasicBlock(32, 64)
    self.b4 = BasicBlock(64, 64)
    # Head: 3136 flattened features -> 512 -> 10
    self.linear1 = nn.Linear(3136, 512)
    self.linear2 = nn.Linear(512, 10)

  def forward(self, x):
    """Return a (batch, 10) tensor of class scores for the batch `x`."""
    batch = x.shape[0]
    # Broadcast the input to 3 channels of 28x28 to match b1's input channels
    out = x.expand(batch, 3, 28, 28)

    # Each stage: two conv blocks, then pooling, then dropout
    for first, second in ((self.b1, self.b2), (self.b3, self.b4)):
      out = second(first(out))
      out = self.dropout(self.maxpool(out))

    # Flatten everything except the batch dimension for the linear head
    flat = out.view(out.size(0), -1)
    hidden = self.dropout(self.linear1(flat))
    return self.linear2(hidden)


# NOTE: batch_size is not referenced anywhere in the visible code; the data
# loaders are constructed with their own batch size.
batch_size = 256

# Run on the GPU when one is available, otherwise fall back to the CPU
# instead of crashing on `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)

# Plain SGD with a learning rate of 0.1
optimizer = optim.SGD(model.parameters(), lr=1e-1)

# CrossEntropyLoss takes the raw scores returned by the model
criterion = nn.CrossEntropyLoss()

epochs = 10




# Per-epoch history: training/validation accuracy and loss
t_accs, v_accs, t_loss, v_loss = [], [], [], []
temp = []  # not used by the visible code; kept in case other cells reference it
model_out_path = './model.pth'

# Move batches to whatever device the model's parameters live on
device = next(model.parameters()).device

for epoch in range(epochs):

  # ---- Training pass ----
  # Re-enable dropout; the validation pass below switches it off.
  model.train()
  train_loss = 0.0
  train_accuracy = 0.0

  for images, labels in trainloader:
    images = images.to(device)
    labels = labels.to(device)

    # Standard optimisation step: clear old gradients, forward, backward, update
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    # Predicted class is the index of the largest score
    predictions = output.argmax(dim=1)
    train_accuracy += (predictions == labels).float().mean().item()

  # Record the epoch averages (mean over batches) so they can be reported below
  t_loss.append(train_loss / len(trainloader))
  t_accs.append(train_accuracy / len(trainloader))

  # ---- Validation pass ----
  test_loss = 0.0
  test_accuracy = 0.0
  # Evaluation mode disables dropout
  model.eval()
  # No gradients are needed for validation
  with torch.no_grad():
    for images, labels in testloader:
      images = images.to(device)
      labels = labels.to(device)
      output = model(images)
      test_loss += criterion(output, labels).item()
      predictions = output.argmax(dim=1)
      test_accuracy += (predictions == labels).float().mean().item()
  v_accs.append(test_accuracy / len(testloader))
  v_loss.append(test_loss / len(testloader))

  print("==> Epoch[{}/{}]".format(epoch + 1, epochs))
  print("loss: {:.3f}, Accuracy: {:.3f}, val_loss: {:.3f}, val_accuracy: {:.3f}"
        .format(t_loss[-1], t_accs[-1], v_loss[-1], v_accs[-1]))

  # Save the weights after every epoch so the checkpoint can be loaded later
  torch.save(model.state_dict(), model_out_path)




# Rebuild the network and restore the trained weights from disk
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
model_out_path = './model.pth'
# map_location lets a checkpoint saved on the GPU load on a CPU-only runtime
checkpoint = torch.load(model_out_path, map_location=device)
model.load_state_dict(checkpoint, strict=True)

kernels = []
weights = []  # min-max normalised conv weights, one array per BasicBlock
bias = []     # conv biases, one array per BasicBlock
for name, param in model.named_parameters():
  # Only the conv layers inside BasicBlock have 'body' in their parameter name
  if 'body' not in name:
    continue
  # Detach from autograd and move to the CPU so the values can be plotted
  values = param.detach().cpu()
  if 'weight' in name:
    # Rescale the whole weight tensor to [0, 1]
    minv, maxv = values.min(), values.max()
    weights.append(((values - minv) / (maxv - minv)).numpy())
  elif 'bias' in name:
    bias.append(values.numpy())

# Visualize conv filter: one column per kernel of the first conv layer,
# one row per input channel
first_layer = weights[0]
n_kernels = first_layer.shape[0]
n_channels = first_layer.shape[1]

plt.figure(figsize=(16, 2))
# A figure-level title; plt.title would create an extra axes under the grid
plt.suptitle("Kernels of conv2d")
for i in range(n_kernels):
  f = first_layer[i, :, :, :]
  for j in range(n_channels):
    plt.subplot(n_channels, n_kernels, j * n_kernels + i + 1)
    plt.imshow(f[j, :, :], cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()




class partial_CNN(nn.Module):
  """Only the first conv block of CNN, used to inspect its feature maps.

  The attribute is named `b1` so that the matching entries of a CNN
  checkpoint load into it.
  """

  def __init__(self):
    super(partial_CNN, self).__init__()
    # Must have the same shape as CNN.b1 (3 input channels): load_state_dict
    # raises on a size mismatch even when strict=False.
    self.b1 = BasicBlock(3, 32)

  def forward(self, x):
    """Return the 32-channel output of the first block for the batch `x`."""
    # Broadcast a single-channel image to 3 channels, as CNN.forward does;
    # -1 keeps the batch and spatial sizes unchanged.
    out = self.b1(x.expand(-1, 3, -1, -1))
    return out

partial_Model = partial_CNN()
# The partial model stays on the CPU, so load the checkpoint there too.
# strict=False skips checkpoint entries for layers this model does not have.
checkpoint = torch.load(model_out_path, map_location='cpu')
partial_Model.load_state_dict(checkpoint, strict=False)
partial_Model.eval()

# One batch is enough to pick a sample image; no need to walk the whole loader
test_images, _ = next(iter(testloader))
x_test = test_images[3]

# Show the input image in its own figure so the grid below does not draw over it
plt.figure()
plt.imshow(x_test[0, :, :], cmap='gray')
plt.show()

# Add a batch dimension, run the first block, and drop the batch dimension again
x_test = x_test.unsqueeze(0)
with torch.no_grad():
  x_test = partial_Model(x_test).numpy().squeeze()

# One panel per output channel of the first block
plt.figure(figsize=(16, 2))
for i in range(32):
  plt.subplot(2, 16, i + 1)
  plt.imshow(x_test[i, :, :], cmap='gray')
  plt.xticks([])
  plt.yticks([])
plt.show()


