Один трюк с PyTorch, который вы должны знать

Как хуки могут значительно улучшить ваш рабочий процесс

Если вы когда-либо раньше использовали глубокое обучение, вы знаете, что отладка модели иногда может быть очень сложной. Несоответствие формы тензор, взрывные градиенты и множество других проблем могут вас удивить. Для их решения необходимо посмотреть на модель под микроскопом. Самые простые методы включают засорение forward() методов операторами печати или введение точек останова. Они, конечно, не очень масштабируемы, потому что они требуют угадывать, где что-то пошло не так, и в целом это довольно утомительно.

Однако выход есть: крючки. Это особые функции, которые могут быть прикреплены к каждому слою и вызваны каждый раз, когда слой используется. По сути, они позволяют заморозить выполнение прямого или обратного прохода в конкретном модуле и обрабатывать его входные и выходные данные.

Давайте посмотрим на них в действии!

Ускоренный курс по крючкам

Итак, ловушка - это просто вызываемый объект с предопределенной сигнатурой, которую можно зарегистрировать для любого nn.Module объекта. Когда в модуле используется метод триггера (т.е. forward() или backward()), сам модуль с его входами и возможными выходами передается в ловушку, выполняясь до того, как вычисления перейдут к следующему модулю.

В PyTorch вы можете зарегистрировать ловушку как

  • передний прехук (выполняется перед передним пасом),
  • передний хук (выполняется после прямого прохода),
  • обратный хук (выполняется после обратного прохода).

Сначала это может показаться сложным, поэтому давайте рассмотрим конкретный пример!

Пример: сохранение выходных данных каждого сверточного слоя

Предположим, что мы хотим проверить вывод каждого сверточного слоя в архитектуре ResNet34. Эта задача отлично подходит для крючков. В следующей части я покажу вам, как это можно сделать. Если вы хотите следить за ним в интерактивном режиме, вы можете найти прилагаемый блокнот Jupyter по адресу https://github.com/cosmic-cortex/pytorch-hooks-tutorial.

Наша модель определяется следующим образом.

Создать ловушку для сохранения выходных данных очень просто, для наших целей вполне достаточно простого вызываемого объекта.

Экземпляр SaveOutput просто запишет выходной тензор прямого прохода и сохранит его в списке.

Перехватчик вперед может быть зарегистрирован с помощью метода register_forward_hook(hook). (Для других типов ловушек у нас есть register_backward_hook и register_forward_pre_hook.) Возвращаемое значение этих методов - дескриптор ловушки, который можно использовать для удаления ловушки из модуля.

Теперь регистрируем крючок для каждого сверточного слоя.

Когда это будет сделано, ловушка будет вызываться после каждого прямого прохода каждого сверточного слоя. Чтобы проверить это, мы собираемся использовать следующее изображение.

Передний проход:

Как и ожидалось, результаты были сохранены правильно.

>>> len(save_output.outputs)
36

Изучая тензоры в этом списке, мы можем визуализировать то, что видит сеть.

Из любопытства мы можем проверить, что будет позже. Если мы углубимся в сеть, изученные функции будут становиться все более высокоуровневыми. Например, есть фильтр, который, кажется, отвечает за обнаружение глаз.

Выходя за рамки

Конечно, это только верхушка айсберга. Хуки могут делать гораздо больше, чем просто сохранять выходные данные промежуточных слоев. Например, обрезка нейронной сети, которая является методом уменьшения количества параметров, также может выполняться с помощью крючков.

Подводя итог, можно сказать, что применение хуков - очень полезный метод, который нужно изучить, если вы хотите улучшить свой рабочий процесс. Имея это за плечами, вы сможете делать гораздо больше и делать их более эффективно.

Если вам нравится разбирать концепции машинного обучения и понимать, что ими движет, у нас много общего. Загляните в мой блог, где я часто публикую подобные технические сообщения!