This Python package implements "Perceiver: General Perception with Iterative Attention" by Andrew Jaegle et al. in TensorFlow. The model builds on Transformers, but the input data enters only through a cross-attention mechanism (see figure), which allows it to scale to hundreds of thousands of inputs, as ConvNets do. This also partly addresses the quadratic compute and memory bottleneck of Transformers.
The Perceiver aims to handle arbitrary configurations of different modalities with a single Transformer-based architecture. Transformers are flexible and make few assumptions about their inputs, but they scale quadratically with the number of inputs in both memory and computation. The Perceiver proposes a mechanism for handling high-dimensional inputs while retaining the expressivity and flexibility to deal with arbitrary input configurations.
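To make the quadratic scaling concrete, here is a back-of-the-envelope count of attention scores for one 224x224 image, assuming every pixel becomes one input token and using 256 latents (the same num_latents as the model configuration used later). It counts only query-key score entries for a single head and ignores feature dimensions, so treat it as an illustration rather than a memory estimate.

```python
# Rough count of attention-score entries for one 224x224 image,
# assuming one input token per pixel (an illustrative assumption).
def attention_entries(num_queries: int, num_keys: int) -> int:
    """Number of query-key scores computed by one attention head."""
    return num_queries * num_keys

M = 224 * 224   # input tokens (pixels): 50,176
N = 256         # latent units

self_attn = attention_entries(M, M)    # classic Transformer over pixels: M x M
cross_attn = attention_entries(N, M)   # Perceiver-style latent cross-attention: N x M
latent_attn = attention_entries(N, N)  # self-attention among latents: N x N

print(f"self-attention over pixels: {self_attn:,}")    # 2,517,630,976
print(f"latent cross-attention:     {cross_attn:,}")   # 12,845,056
print(f"latent self-attention:      {latent_attn:,}")  # 65,536
```

The cross-attention grows linearly with the number of inputs, and here it is 196 times (M / N) smaller than full self-attention over the pixels.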
The idea is to introduce a small set of latent units that forms an attention bottleneck through which the inputs must pass. This avoids the quadratic scaling of the all-to-all attention in a classical Transformer: with M inputs and N latents (N much smaller than M), cross-attention costs O(MN) instead of O(M²), and self-attention among the latents costs only O(N²). The model can be seen as performing fully end-to-end clustering of the inputs, with the latent units as the cluster centres, using a highly asymmetric cross-attention layer. Because the model has no explicit grid structure, the authors compensate for the missing spatial information by adding Fourier feature position encodings to the inputs.
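The following NumPy sketch illustrates the asymmetric cross-attention at the core of this idea: the N latents act as queries and the M inputs supply keys and values, so the score matrix is N x M and the output has N rows however large M is. It is a simplified illustration, not this package's implementation: it uses a single head, random weights in place of learned ones, hypothetical sizes, and omits the output projection, residual connection and layer norm.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def latent_cross_attention(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: latents (N, D) query inputs (M, C).

    Returns the (N, d_head) output and the (N, M) attention weights.
    """
    q = latents @ w_q                          # (N, d_head)
    k = inputs @ w_k                           # (M, d_head)
    v = inputs @ w_v                           # (M, d_head)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (N, M): linear in M
    weights = softmax(scores, axis=-1)         # each latent's weights over the inputs sum to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
M, C = 1024, 29              # input tokens and features per token (hypothetical sizes)
N, D, d_head = 64, 128, 32   # latents, latent width, head width (hypothetical sizes)

inputs = rng.normal(size=(M, C))
latents = rng.normal(size=(N, D))   # learned parameters in a real model
w_q = rng.normal(size=(D, d_head))
w_k = rng.normal(size=(C, d_head))
w_v = rng.normal(size=(C, d_head))

out, weights = latent_cross_attention(latents, inputs, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (64, 32) (64, 1024)
```

Because every row of `weights` sums to 1, each latent ends up as a weighted average of the projected inputs, which is what the clustering interpretation above refers to.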
!pip install perceiver
Collecting perceiver
Requirement already satisfied: tensorflow~=2.4.0 in /usr/local/lib/python3.7/dist-packages (from perceiver) (2.4.1)
Collecting einops>=0.3
Installing collected packages: einops, perceiver
Successfully installed einops-0.3.0 perceiver-0.1.0
import tensorflow as tf
from perceiver import Perceiver
model = Perceiver(
    input_channels = 3,      # number of channels per input token (3 for RGB images)
    input_axis = 2,          # number of input axes (2 for images, 3 for video)
    num_freq_bands = 6,      # number of Fourier frequency bands K; the original encoding gives 2 * K + 1 features per axis
    max_freq = 10.,          # maximum frequency, a hyperparameter that depends on how fine the data is
    depth = 6,               # depth of the network
    num_latents = 256,       # number of latent vectors (size of the attention bottleneck)
    latent_dim = 512,        # dimension of each latent vector
    cross_heads = 1,         # number of cross-attention heads (the paper uses 1)
    latent_heads = 8,        # number of latent self-attention heads (the paper uses 8)
    cross_dim_head = 64,     # dimension of each cross-attention head
    latent_dim_head = 64,    # dimension of each latent self-attention head
    num_classes = 1000,      # number of output classes
    attn_dropout = 0.,       # dropout rate for attention
    ff_dropout = 0.,         # dropout rate for feed-forward layers
)
img = tf.random.normal([1, 224, 224, 3])  # random stand-in for one ImageNet-sized image (batch, height, width, channels)
model(img)  # logits of shape (1, 1000)
WARNING:tensorflow:Model was constructed with shape (None, 512) for input KerasTensor(type_spec=TensorSpec(shape=(None, 512), dtype=tf.float32, name='dense_3_input'), name='dense_3_input', description="created by layer 'dense_3_input'"), but it was called on an input with incompatible shape (1, 256, 512).
(The same warning is repeated for dense_8_input, dense_13_input and so on up to dense_58_input, twelve warnings in total; the call still returns the expected (1, 1000) output.)
<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=
array([[-1.39374673e-01,  4.31531779e-02, -1.53121054e-01, ...,
        -1.03540137e-01, -1.66932438e-02, -3.19537789e-01]], dtype=float32)>
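The result is a (1, 1000) tensor of logits from an untrained model, so the values themselves carry no meaning. To see how an image ends up as 1,000 numbers, here is a deliberately simplified NumPy sketch of the data flow. It is not this package's code: it drops multi-head projections, layer norms, feed-forward blocks and the batch axis, uses small hypothetical sizes, and assumes (without having checked the package) that each of the `depth` blocks applies one cross-attention followed by one latent self-attention and that the classifier averages the latents before a final linear layer.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(queries, keys_values):
    """Projection-free single-head attention, for shape illustration only."""
    scores = queries @ keys_values.T / np.sqrt(queries.shape[-1])
    return softmax(scores) @ keys_values

def toy_perceiver(inputs, latents, depth, w_in, w_out):
    """Hypothetical toy forward pass.

    inputs:  (M, C) flattened input tokens
    latents: (N, D) latent array (learned in a real model)
    w_in:    (C, D) maps input tokens to the latent width
    w_out:   (D, num_classes) linear classification head
    """
    kv = inputs @ w_in                 # (M, D)
    x = latents
    for _ in range(depth):
        x = x + attend(x, kv)          # cross-attention: N x M scores
        x = x + attend(x, x)           # latent self-attention: N x N scores
    pooled = x.mean(axis=0)            # assumed: average over the N latents -> (D,)
    return pooled @ w_out              # (num_classes,)

rng = np.random.default_rng(0)
M, C = 32 * 32, 29                     # small hypothetical input: 1,024 tokens, 29 features each
N, D, num_classes = 256, 64, 1000

logits = toy_perceiver(
    inputs=rng.normal(size=(M, C)),
    latents=rng.normal(size=(N, D)),
    depth=6,
    w_in=rng.normal(size=(C, D)) / np.sqrt(C),
    w_out=rng.normal(size=(D, num_classes)) / np.sqrt(D),
)
print(logits.shape)  # (1000,)
```

Whatever the number of input tokens M, the latent array stays (N, D) throughout, so the cost of the deep part of the network depends on N rather than on the input size.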
model(img).shape
TensorShape([1, 1000])
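The `num_freq_bands` and `max_freq` arguments control the Fourier position features mentioned at the start. The following NumPy sketch follows the scheme described in the Perceiver paper: for each axis, positions are scaled to [-1, 1], K frequencies are spaced linearly from 1 to max_freq / 2, and each position gets sin and cos features per frequency plus the raw coordinate, i.e. 2K + 1 features per axis. It is an illustration only; the TensorFlow package may differ in details such as ordering or scaling.

```python
import numpy as np

def fourier_encode(pos, max_freq, num_bands):
    """Encode positions in [-1, 1] as [sin, cos, raw] features.

    pos: array of any shape S -> returns shape S + (2 * num_bands + 1,)
    """
    pos = pos[..., None]                                # S + (1,)
    freqs = np.linspace(1.0, max_freq / 2, num_bands)   # (num_bands,)
    scaled = pos * freqs * np.pi                        # S + (num_bands,)
    return np.concatenate([np.sin(scaled), np.cos(scaled), pos], axis=-1)

def image_to_tokens(img, max_freq=10.0, num_bands=6):
    """Flatten an (H, W, C) image into (H*W, C + 2 * (2 * num_bands + 1)) tokens."""
    h, w, _ = img.shape
    ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing="ij")  # (H, W) each
    enc = np.concatenate([fourier_encode(ys, max_freq, num_bands),
                          fourier_encode(xs, max_freq, num_bands)], axis=-1)  # (H, W, 2 * (2K + 1))
    tokens = np.concatenate([img, enc], axis=-1)        # pixel channels + position features
    return tokens.reshape(h * w, -1)

img = np.random.default_rng(0).normal(size=(224, 224, 3))
tokens = image_to_tokens(img)
print(tokens.shape)  # (50176, 29)
```

With `input_channels = 3`, `input_axis = 2` and `num_freq_bands = 6`, each pixel becomes a 3 + 2 x 13 = 29-dimensional token, and the 224x224 image becomes 50,176 such tokens for the latents to attend to.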