Update modeling_llama2.py
Browse files- modeling_llama2.py +1 -1
modeling_llama2.py
CHANGED
|
@@ -22,7 +22,7 @@ from transformers.models.llama.modeling_llama import *
|
|
| 22 |
from transformers.configuration_utils import PretrainedConfig
|
| 23 |
from transformers.utils import logging
|
| 24 |
|
| 25 |
-
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 26 |
from .configuration_mplug_owl2 import LlamaConfig
|
| 27 |
|
| 28 |
class MultiwayNetwork(nn.Module):
|
|
|
|
| 22 |
from transformers.configuration_utils import PretrainedConfig
|
| 23 |
from transformers.utils import logging
|
| 24 |
|
| 25 |
+
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
| 26 |
from .configuration_mplug_owl2 import LlamaConfig
|
| 27 |
|
| 28 |
class MultiwayNetwork(nn.Module):
|