This is the full output I got when I ran the model-loading code on Colab, after installing all of the prerequisites (transformer version=0.3.5).
The error is `AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'`, and I wasn't able to find any other reports of this error online.
UnfilteredStackTrace Traceback (most recent call last)
in <module>()
2 tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
----> 3 model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
350 # init random models
--> 351 model = cls(config, *model_args, **model_kwargs)
352
/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
928 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 929 super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
930
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
105 # randomly initialized parameters
--> 106 random_params = self.init_weights(self.key, input_shape)
107
/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in init_weights(self, rng, input_shape)
953 position_ids,
--> 954 decoder_position_ids,
955 )["params"]
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
1122 rngs, *args,
-> 1123 method=method, mutable=mutable, **kwargs)
1124 return v_out
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
1090 return self.apply(
-> 1091 {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
1092
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
1059 mutable=mutable, capture_intermediates=capture_intermediates
-> 1060 )(variables, *args, **kwargs, rngs=rngs)
1061
/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
690 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 691 y = fn(root, *args, **kwargs)
692 if mutable is not False:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
1311 try:
-> 1312 return fn(module.clone(parent=scope), *args, **kwargs)
1313 finally:
/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py in wrapped_fn(self, *args, **kwargs)
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _call_wrapped_method(self, fun, args, kwargs)
592 else:
--> 593 self._try_setup()
594
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _try_setup(self, shallow)
788 if not shallow:
--> 789 self.setup()
790 finally:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _call_wrapped_method(self, fun, args, kwargs)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/dalle_mini/model.py in setup(self)
50 self.lm_head = nn.Dense(
---> 51 self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
52 use_bias=False,
/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py in __getattribute__(self, key)
236 key = super().__getattribute__("attribute_map")[key]
--> 237 return super().__getattribute__(key)
238
UnfilteredStackTrace: AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
in <module>()
1 # set up tokenizer and model
2 tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
----> 3 model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
349
350 # init random models
--> 351 model = cls(config, *model_args, **model_kwargs)
352
353 if from_pt:
/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
927 ):
928 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 929 super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
930
931 def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
104
105 # randomly initialized parameters
--> 106 random_params = self.init_weights(self.key, input_shape)
107
108 # save required_params as set
/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in init_weights(self, rng, input_shape)
952 decoder_attention_mask,
953 position_ids,
--> 954 decoder_position_ids,
955 )["params"]
956
/usr/local/lib/python3.7/dist-packages/dalle_mini/model.py in setup(self)
49 self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
50 self.lm_head = nn.Dense(
---> 51 self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
52 use_bias=False,
53 kernel_init=jax.nn.initializers.normal(self.config.init_std),
/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py in __getattribute__(self, key)
235 if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
236 key = super().__getattribute__("attribute_map")[key]
--> 237 return super().__getattribute__(key)
238
239 def __init__(self, **kwargs):
AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'