
[Python] Multimodal Named Entity Recognition

Discussion in 'Python' started by Stack, October 4, 2024 at 20:32.

  1. Stack

    Stack (Participating Member)

    My project is multimodal Named Entity Recognition. I'm trying to create a multimodal BERT. I first create image vectors using ResNet and text vectors using BERT, and feed both into the multimodal BERT (which contains a multi-head attention mechanism, cross-attention, and a feed-forward layer). This model should generate a sequence of vectors that combines the text and image vectors, and I also want it to generate labels. That sequence of vectors is then fed into a Bi-LSTM for contextual enhancement and finally a CRF to produce the labels. I've written the code for the multimodal BERT, and I'm getting an error.
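
    For context, here is what the ResNet step described above could look like. This is only a minimal sketch, assuming a torchvision ResNet-50 backbone with its pooling/classifier head removed, so each 224x224 image yields a 7x7 grid of 49 region vectors of 2048 features, projected down to BERT's 768 dimensions; none of these choices appear in the posted code, which follows after this sketch.

    import torch
    import torch.nn as nn
    from torchvision import models

    # Assumed feature extractor: ResNet-50 without avgpool/fc, so the output
    # keeps its 7x7 spatial grid of 2048-dim region features.
    backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    feature_extractor = nn.Sequential(*list(backbone.children())[:-2])
    feature_extractor.eval()

    projection = nn.Linear(2048, 768)  # map region features to BERT's hidden size

    with torch.no_grad():
        images = torch.randn(4, 3, 224, 224)      # dummy batch of preprocessed images
        feats = feature_extractor(images)         # [4, 2048, 7, 7]
        feats = feats.flatten(2).transpose(1, 2)  # [4, 49, 2048] -- 49 regions
        image_embeddings = projection(feats)      # [4, 49, 768]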

    import os
    import re
    import random
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from PIL import Image
    from transformers import BertModel, BertTokenizer, AdamW
    from TorchCRF import CRF

    # Prepare the BERT tokenizer and image preprocessing
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Preprocess images for ResNet
    preprocess_image = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Function to load and preprocess an image
    def load_image(image_path):
        image = Image.open(image_path).convert('RGB')
        image = preprocess_image(image).unsqueeze(0)  # Add batch dimension
        return image

    # Function to parse the text file
    def parse_text_file(file_path):
        data = {}
        current_imgid = None
        accumulated_sentence = []
        accumulated_labels = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith("IMGID:"):
                    if current_imgid and accumulated_sentence:
                        # Save the previous IMGID's sentence and labels
                        data[current_imgid] = {
                            'sentence': " ".join(accumulated_sentence),
                            'labels': accumulated_labels
                        }

                    # Start processing a new IMGID
                    current_imgid = line.split(':')[1].strip()
                    accumulated_sentence = []
                    accumulated_labels = []

                elif line:  # Non-empty line (word-label pairs)
                    word, label = line.split()
                    if not re.match(r'https?://\S+', word):  # Ignore URLs
                        accumulated_sentence.append(word)
                        accumulated_labels.append(label)

        # Process the last IMGID
        if current_imgid and accumulated_sentence:
            data[current_imgid] = {
                'sentence': " ".join(accumulated_sentence),
                'labels': accumulated_labels
            }

        return data

    # Dataset class combining images and text
    class MultimodalDataset(Dataset):
        def __init__(self, image_dir, text_file):
            self.image_dir = image_dir
            self.data = parse_text_file(text_file)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            # Get image ID
            imgid = list(self.data.keys())[idx]
            img_path = os.path.join(self.image_dir, f"{imgid}.jpg")

            # Load and preprocess the image
            image = load_image(img_path)

            # Get text and labels
            text_data = self.data[imgid]
            sentence = text_data['sentence']
            labels = text_data['labels']

            # Tokenize text and convert to BERT format
            inputs = self.tokenizer(sentence, return_tensors='pt', padding='max_length', truncation=True, max_length=128)

            # Pad/truncate labels to the tokenized length (128), using "O" for padding
            padded_labels = [labels[i] if i < len(labels) else "O" for i in range(128)]

            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'image': image.squeeze(0),
                'labels': padded_labels  # Use padded labels
            }

    # Function to randomly mask tokens in the input text
    def mask_tokens(inputs, tokenizer, mask_prob=0.15):
        labels = inputs.clone()
        rand = torch.rand(inputs.shape)
        mask_arr = (rand < mask_prob) & (inputs != tokenizer.cls_token_id) & \
                   (inputs != tokenizer.sep_token_id) & (inputs != tokenizer.pad_token_id)

        inputs[mask_arr] = tokenizer.mask_token_id
        labels[~mask_arr] = -100  # Only compute loss on masked tokens
        return inputs, labels

    # Function to randomly mask regions of the image
    def mask_image(image, mask_prob=0.15):
        masked_image = image.clone()
        rand = torch.rand(image.shape)
        mask_arr = rand < mask_prob
        masked_image[mask_arr] = 0  # Masked regions set to 0
        return masked_image

    # Multimodal BERT Model
    class MultimodalBERT(nn.Module):
        def __init__(self, hidden_dim=768, num_labels=9):  # num_labels = number of NER tags
            super(MultimodalBERT, self).__init__()

            # Image and text embedding dimensions (768 for BERT)
            self.image_linear = nn.Linear(768, hidden_dim)  # Map image embeddings to the same size as BERT embeddings
            self.text_bert = BertModel.from_pretrained('bert-base-uncased')  # BERT model for text embeddings

            # Attention mechanism for cross-modal interaction
            self.multihead_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, dropout=0.1)

            # Feed-forward neural network after attention
            self.feed_forward = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

            # Linear layer to map to the number of NER labels
            self.classifier = nn.Linear(hidden_dim, num_labels)  # Map hidden_dim to num_labels (e.g., 9 NER tags)

            # CRF layer for named entity recognition
            self.crf = CRF(num_labels, batch_first=True)

        def forward(self, text_inputs, text_attention_mask, image_features, labels=None, mlm_labels=None):
            # Process text using BERT
            text_outputs = self.text_bert(input_ids=text_inputs, attention_mask=text_attention_mask)
            text_embeddings = text_outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]

            # Process image features
            image_embeddings = self.image_linear(image_features)  # [batch_size, num_regions (49), hidden_dim]

            # Multimodal cross-attention between image and text
            multimodal_output, _ = self.multihead_attention(
                text_embeddings.permute(1, 0, 2),   # Query: text embeddings
                image_embeddings.permute(1, 0, 2),  # Key: image embeddings
                image_embeddings.permute(1, 0, 2))  # Value: image embeddings
            multimodal_output = multimodal_output.permute(1, 0, 2)  # Reshape back to [batch_size, seq_len, hidden_dim]

            # Feed-forward layer
            output = self.feed_forward(multimodal_output)

            # Classify each token for NER
            emissions = self.classifier(output)  # [batch_size, seq_len, num_labels]

            # Predict NER labels using CRF
            if labels is not None:
                log_likelihood = self.crf(emissions, labels, mask=text_attention_mask.bool())
                return -log_likelihood
            else:
                predictions = self.crf.decode(emissions, mask=text_attention_mask.bool())
                return predictions

    # Training Function
    def train_model(model, dataloader, optimizer, num_epochs=10, save_path="multimodal_bert.pth"):
        model.train()

        for epoch in range(num_epochs):
            total_loss = 0
            for batch in dataloader:
                text_inputs, attention_masks, images, labels = batch['input_ids'], batch['attention_mask'], batch['image'], batch['labels']

                # Apply MLM masking
                mlm_inputs, mlm_labels = mask_tokens(text_inputs, tokenizer)
                masked_images = mask_image(images)

                # Forward pass
                loss = model(mlm_inputs, attention_masks, masked_images, labels=labels, mlm_labels=mlm_labels)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader)}')

        # Save the model
        torch.save(model.state_dict(), save_path)
        print(f'Model saved to {save_path}')

    # Function to load the model
    def load_model(model, load_path="multimodal_bert.pth"):
        model.load_state_dict(torch.load(load_path))
        model.eval()
        print(f'Model loaded from {load_path}')
        return model

    # Main function to run training
    if __name__ == "__main__":
        # Define paths
        image_dir = "C:/Python312/MNER_Proj/images"  # Update with your image directory
        text_file = "C:/Python312/MNER_Proj/images_sentences.txt"  # Update with your text file path

        # Create dataset and dataloader
        dataset = MultimodalDataset(image_dir, text_file)
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

        # Instantiate model and optimizer
        model = MultimodalBERT(hidden_dim=768, num_labels=9)
        optimizer = AdamW(model.parameters(), lr=5e-5)

        # Train the model
        train_model(model, dataloader, optimizer, num_epochs=10, save_path="multimodal_bert.pth")
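
    For completeness, the Bi-LSTM stage described at the top is not part of this listing yet. A minimal sketch of what that contextual layer might look like (all names here are hypothetical, not from the posted code):

    import torch
    import torch.nn as nn

    # Hypothetical Bi-LSTM head: it would sit between the multimodal encoder's
    # output and the CRF emissions, for contextual enhancement.
    class BiLSTMHead(nn.Module):
        def __init__(self, hidden_dim=768, num_labels=9):
            super().__init__()
            # bidirectional=True doubles the output width, so halve the hidden size
            self.bilstm = nn.LSTM(hidden_dim, hidden_dim // 2,
                                  batch_first=True, bidirectional=True)
            self.classifier = nn.Linear(hidden_dim, num_labels)

        def forward(self, sequence):               # [batch, seq_len, hidden_dim]
            contextual, _ = self.bilstm(sequence)  # [batch, seq_len, hidden_dim]
            return self.classifier(contextual)     # emissions for the CRF

    emissions = BiLSTMHead()(torch.randn(2, 128, 768))
    print(emissions.shape)  # torch.Size([2, 128, 9])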


    I'm getting this error:

    PS C:\Python312\MNER_Proj> python main_2.py
    C:\Python312\MNER_Proj\venv\Lib\site-packages\transformers\optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
      warnings.warn(
    Traceback (most recent call last):
      File "C:\Python312\MNER_Proj\main_2.py", line 228, in <module>
        train_model(model, dataloader, optimizer, num_epochs=10, save_path="multimodal_bert.pth")
      File "C:\Python312\MNER_Proj\main_2.py", line 191, in train_model
        loss = model(mlm_inputs, attention_masks, masked_images, labels=labels, mlm_labels=mlm_labels)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\venv\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\venv\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\main_2.py", line 155, in forward
        image_embeddings = self.image_linear(image_features) # [batch_size, num_regions (49), hidden_dim]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\venv\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\venv\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "C:\Python312\MNER_Proj\venv\Lib\site-packages\torch\nn\modules\linear.py", line 117, in forward
        return F.linear(input, self.weight, self.bias)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (2688x224 and 768x768)
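
    A note on the shapes in this traceback: mat1 is 2688x224, which is exactly what a raw [4, 3, 224, 224] image batch looks like once nn.Linear flattens its leading dimensions (4 * 3 * 224 = 2688), so the images appear to reach image_linear without ever passing through ResNet. Below is a minimal reproduction, plus one possible correction assuming 2048-dim ResNet-50 region features as in the sketch above (an assumption, not something in the posted code):

    import torch
    import torch.nn as nn

    # Reproduce the mismatch: a raw image batch hits Linear(768, 768) directly.
    image_linear = nn.Linear(768, 768)
    images = torch.randn(4, 3, 224, 224)  # last dim is 224, not 768
    try:
        image_linear(images)
    except RuntimeError as e:
        print(e)  # mat1 and mat2 shapes cannot be multiplied (2688x224 and 768x768)

    # One possible correction: feed region features of the size the layer expects.
    feats = torch.randn(4, 49, 2048)     # [B, 49 regions, 2048] from a truncated ResNet-50
    image_linear = nn.Linear(2048, 768)  # in_features must match the feature size
    print(image_linear(feats).shape)     # torch.Size([4, 49, 768])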

