Поиск по сайту:

Как использовать факел PyTorch.max()


В этой статье мы рассмотрим использование функции PyTorch torch.max().

Как и следовало ожидать, это очень простая функция, но, что интересно, она имеет больше, чем вы себе представляете.

Давайте рассмотрим использование этой функции на нескольких простых примерах.

ПРИМЕЧАНИЕ. На момент написания использовалась версия PyTorch 1.5.0.

PyTorch torch.max() — базовый синтаксис

Чтобы использовать PyTorch torch.max(), сначала импортируйте torch.

import torch

Теперь эта функция возвращает максимум среди элементов тензора.

Поведение PyTorch по умолчанию torch.max()

По умолчанию возвращается один элемент и индекс, соответствующий глобальному максимальному элементу.

max_element = torch.max(input_tensor)

Вот пример:

p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)

Выход

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor(2.7976)

Действительно, это дает нам глобальный максимальный элемент в тензоре!

Используйте torch.max() по измерению

Однако вы можете захотеть получить максимум по определенному измерению в виде тензора, а не одного элемента.

Чтобы указать размер (ось - в numpy), есть еще один необязательный ключевой аргумент, называемый dim

Это представляет направление, которое мы берем для максимума.

Это возвращает кортеж, max_elements и max_indices.

  • max_elements -> Все максимальные элементы тензора.
  • max_indices -> Индексы, соответствующие максимальному количеству элементов.

max_elements, max_indices = torch.max(input_tensor, dim)

Это вернет тензор, который имеет максимальное количество элементов по измерению dim.

Давайте теперь посмотрим на некоторые примеры.

p = torch.randn([2, 3])
print(p)

# Get the maximum along dim = 0 (axis = 0)
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)

Выход

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])

Как видите, мы находим максимум по размерности 0 (максимум по столбцам).

Также мы получаем индексы, соответствующие элементам. Например, 0,0688 имеет индекс 1 в столбце 0.

Точно так же, если вы хотите найти максимум по строкам, используйте dim=1.

# Get the maximum along dim = 1 (axis = 1)
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)

Выход

tensor([2.7976, 1.4443])
tensor([1, 2])

Действительно, мы получаем максимум элементов по строке и соответствующий индекс (по строке).

Использование torch.max() для сравнения

Мы также можем использовать torch.max(), чтобы получить максимальные значения между двумя тензорами.

output_tensor = torch.max(a, b)

Здесь a и b должны иметь одинаковые размеры или должны быть транслируемыми тензорами.

Вот простой пример для сравнения двух тензоров, имеющих одинаковые размеры.

p = torch.randn([2, 3])
q = torch.randn([2, 3])

print("p =", p)
print("q =",q)

# Compare elements of p and q and get the maximum
max_elements = torch.max(p, q)

print(max_elements)

Выход

p = tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
q = tensor([[-0.0678,  0.2042,  0.8254],
        [-0.1530,  0.0581, -0.3694]])
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688,  0.0581,  1.4443]])

Действительно, мы получаем выходной тензор с максимальным количеством элементов между p и q.

Заключение

В этой статье мы узнали об использовании функции torch.max(), чтобы узнать максимальный элемент тензора.

Мы также использовали эту функцию для сравнения двух тензоров и получения максимального среди них.

Чтобы найти похожие статьи, просмотрите наш контент в наших руководствах по PyTorch! Оставайтесь с нами, чтобы узнать больше!

Рекомендации

  • Официальная документация PyTorch по torch.max()