Commit 2877f1a6 by 靓靓

upload file

parent 5d2bfa4d
Download link for the Phase 2 code materials<br>
Link: https://pan.baidu.com/s/1x6_rqxB4Y0M7lORY3jWgMg?pwd=DTAI<br>
Extraction code: DTAI<br>
The pretrained models, the data, and the trained models are all in this Baidu Netdisk link:
Link: https://pan.baidu.com/s/1natUeRXSQ1P9GZ84qvgzZA
Extraction code: bg40
import os.path
import torch
# Choose the concrete model architecture by editing self.bert_name and self.resnet_name.
class Config(object):
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device
        self.dropout = 0.3  # dropout rate
        self.require_improvement = 2000  # stop training early if there is no improvement after 2000 batches
        self.num_classes = 2  # number of classes, no need to change
        self.num_epochs = 20  # number of epochs
        self.batch_size = 32  # mini-batch size, set according to GPU memory
        self.pad_size = 128  # every sentence is padded or truncated to this length
        self.bert_learning_rate = 1e-5  # BERT learning rate; minirbt-h256 needs a larger one such as 1e-4, 1e-5 works better for the other BERT models
        self.resnet_learning_rate = 2e-5  # ResNet learning rate, best kept slightly higher than the BERT one
        self.other_learning_rate = 2e-5  # learning rate for the remaining layers
        self.frac = 1  # fraction of the data to use; handy for tuning since training is slow (1 = all data, 0.1 = one tenth)
        self.bert_model_path = 'bert_model/minirbt-h256'
        # ['bert_model/bert-base-chinese', 'bert_model/chinese-bert-wwm-ext', 'bert_model/minirbt-h256']
        self.bert_name = 'bert_model/minirbt-h256'  # BERT variant, three choices
        self.bert_fc = 256  # BERT output dimension: 768 for bert-base-chinese and chinese-bert-wwm-ext, 256 for minirbt-h256
        # ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
        self.resnet_name = 'resnet18'  # ResNet variant, five choices
        self.resnet_fc = self.bert_fc  # ResNet output dimension, must match the BERT one
        self.usesloss = True  # whether to use the supervised contrastive loss
        if not os.path.exists('model'):
            os.makedirs('model')
        if self.usesloss:
            self.save_path = 'model/' + 'S_' + self.bert_name.replace('bert_model/', '') + '_' + self.resnet_name + '.ckpt'  # checkpoint path
            self.log_dir = './log/' + 'S_' + self.bert_name.replace('bert_model/', '') + '_' + self.resnet_name  # tensorboard log dir
        else:
            self.save_path = 'model/' + self.bert_name.replace('bert_model/', '') + '_' + self.resnet_name + '.ckpt'  # checkpoint path
            self.log_dir = './log/' + self.bert_name.replace('bert_model/', '') + '_' + self.resnet_name  # tensorboard log dir
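# Usage note (editor's sketch, not part of the original file): to switch backbones,
# edit bert_name / bert_fc / resnet_name inside __init__ above, e.g.
#   self.bert_name = 'bert_model/chinese-bert-wwm-ext'
#   self.bert_fc = 768            # wwm-ext pools to 768-d features
#   self.resnet_name = 'resnet50'
# save_path and log_dir are derived from these names inside __init__, so change
# the attributes there rather than on an already-constructed Config instance.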
import pandas as pd
import os
import csv
import numpy as np
# reorganize the raw data
imgs=os.listdir('./data/images')
# print(imgs)
def new_data(path, label, newpath):
    len_list = []
    with open(path, 'r', encoding='utf-8') as t1:
        t1 = t1.readlines()
    if len(t1) % 3 == 0:
        num = int(len(t1) / 3)
        print('number of lines:', len(t1))
        print('number of samples (lines / 3):', num)
        for n in range(num):
            # a2 holds the image names, a3 the text
            a1, a2, a3 = t1[n * 3], t1[n * 3 + 1], t1[n * 3 + 2]
            a2 = a2.strip()  # strip the trailing newline
            a3 = a3.strip()
            text_len = len(a3)
            len_list.append(text_len)
            a2 = a2.split('|')  # split the image field
            a2 = [x.split('/')[-1] for x in a2 if x != 'null']  # drop null entries and keep only the file name
            a2 = ['./data/images/' + x for x in a2 if x in imgs]
            for m in a2:
                all_info = m, a3, label
                # print(all_info)
                with open(newpath, 'a', encoding='utf-8', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(all_info)
    else:
        print('unexpected file length (not a multiple of 3)')
    print('average sentence length:', np.mean(len_list))
if os.path.exists('./data/train.csv'):
    os.remove('./data/train.csv')  # delete an existing file to avoid appending duplicates
with open('./data/train.csv', 'a', encoding='utf-8', newline='') as f:  # write the header row
    writer = csv.writer(f)
    writer.writerow(('path', 'text', 'label'))
new_data('./data/tweets/train_rumor.txt', 1, './data/train.csv')  # training rumor data, labeled 1
new_data('./data/tweets/train_nonrumor.txt', 0, './data/train.csv')  # training non-rumor data, labeled 0
if os.path.exists('./data/test.csv'):
    os.remove('./data/test.csv')
with open('./data/test.csv', 'a', encoding='utf-8', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(('path', 'text', 'label'))
new_data('./data/tweets/test_rumor.txt', 1, './data/test.csv')  # test rumor data
new_data('./data/tweets/test_nonrumor.txt', 0, './data/test.csv')  # test non-rumor data
df = pd.read_csv('./data/train.csv', encoding='utf-8')
val_df = df.sample(frac=0.1)  # split off the validation set
train_df = df.drop(index=val_df.index.to_list())  # the rest is the training set
print('training set size:', len(train_df))
print('validation set size:', len(val_df))
val_df.to_csv('./data/val.csv', encoding='utf-8', index=None)
train_df.to_csv('./data/train.csv', encoding='utf-8', index=None)
++ "b/02-\347\254\254\344\270\211\346\234\237/\344\273\243\347\240\201\350\265\204\346\226\231/clipfwx-main/shizhanban/log/S_minirbt-h256_resnet18/events.out.tfevents.1717241988.xuexi"
import torch.nn as nn
from transformers import BertModel
import torch
import torch.nn.functional as F
from resnet_models import resnet18, resnet34, resnet50, resnet101, resnet152  # explicitly import the ResNet builders
class SupConLoss(nn.Module):
def __init__(self, temperature=0.1, scale_by_temperature=True):
super(SupConLoss, self).__init__()
self.temperature = temperature
self.scale_by_temperature = scale_by_temperature
def forward(self, features, labels=None, mask=None):
device = (torch.device('cuda')
if features.is_cuda
else torch.device('cpu'))
features = F.normalize(features, p=2, dim=1)
batch_size = features.shape[0]
        # labels and mask are mutually exclusive: when labels are given, the mask is derived from them
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # unsupervised case: the mask is the identity matrix, so only (i, i) pairs are positives
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            # supervised case: mask_{i,j} = 1 whenever samples i and j share a label
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)
        # compute logits: pairwise dot-product similarities scaled by the temperature
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)
        # for numerical stability, subtract the per-row maximum
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        exp_logits = torch.exp(logits)
        # logits_mask zeroes the diagonal: a sample is never its own positive
        logits_mask = torch.ones_like(mask).to(device) - torch.eye(batch_size).to(device)
        positives_mask = mask * logits_mask
        negatives_mask = 1. - mask
        num_positives_per_row = torch.sum(positives_mask, axis=1)  # positives per row excluding the sample itself, e.g. [2, 0, 2, 2]
        denominator = torch.sum(
            exp_logits * negatives_mask, axis=1, keepdims=True) + torch.sum(
            exp_logits * positives_mask, axis=1, keepdims=True)
        log_probs = logits - torch.log(denominator)
if torch.any(torch.isnan(log_probs)):
raise ValueError("Log_prob has nan!")
log_probs = torch.sum(
log_probs * positives_mask, axis=1)[num_positives_per_row > 0] / num_positives_per_row[
num_positives_per_row > 0]
# loss
loss = -log_probs
if self.scale_by_temperature:
loss *= self.temperature
loss = loss.mean()
return loss
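# Hedged usage example (editor's sketch, not part of the original file): the loss
# expects an (N, D) feature batch and an (N,) label vector; features are
# L2-normalized inside forward, so raw network outputs can be passed directly.
#   criterion = SupConLoss(temperature=0.1)
#   feats = torch.randn(8, 256)           # e.g. the image features returned by Mynet
#   targets = torch.randint(0, 2, (8,))   # rumor / non-rumor labels
#   loss = criterion(feats, labels=targets)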
class Mynet(nn.Module):
def __init__(self,config):
super(Mynet, self).__init__()
self.config=config
        resnet_name = self.config.resnet_name  # pick the ResNet variant
        if resnet_name == 'resnet18':
            self.resnet = resnet18(self.config.resnet_fc)
        elif resnet_name == 'resnet34':
            self.resnet = resnet34(self.config.resnet_fc)
        elif resnet_name == 'resnet50':
            self.resnet = resnet50(self.config.resnet_fc)
        elif resnet_name == 'resnet101':
            self.resnet = resnet101(self.config.resnet_fc)
        elif resnet_name == 'resnet152':
            self.resnet = resnet152(self.config.resnet_fc)
        self.bert = BertModel.from_pretrained(self.config.bert_name)  # the chosen BERT
        self.fc_1 = nn.Linear(self.config.bert_fc + self.config.resnet_fc, self.config.num_classes)
        self.drop = nn.Dropout(self.config.dropout)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inx):
        img, tokens, mask = inx
        img = self.resnet(img)  # image branch: (B, 3, 224, 224) -> (B, resnet_fc)
        outputs = self.bert(tokens, attention_mask=mask)  # text branch: (B, pad_size) -> pooled (B, bert_fc)
        pooled_output = outputs[1]
        pooled_output = self.drop(pooled_output)
        fea = torch.cat([img, pooled_output], 1)  # fused feature: (B, bert_fc + resnet_fc)
        # fea = self.drop(fea)
        logits = self.fc_1(fea)
        logits = self.softmax(logits)
        return img, logits  # the first return value feeds the contrastive loss: img = image features, fea would be the fused features
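# Hedged shape sketch (editor's note, not part of the original file), assuming the
# minirbt-h256 checkpoint is available locally so Mynet(Config()) can load BERT:
#   model = Mynet(Config())
#   img = torch.rand(2, 3, 224, 224)
#   ids = torch.randint(0, 100, (2, 128))
#   mask = torch.ones(2, 128, dtype=torch.long)
#   feats, probs = model((img, ids, mask))  # feats: (2, 256), probs: (2, 2)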
import torch.nn as nn
import math
# 3x3 convolution helper
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
# the BasicBlock used by ResNet-18/34
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
        # inplanes is the number of input channels, planes the number of output channels
super(BasicBlock, self).__init__()
# Conv1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
# Conv2
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
        # downsampling
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
    expansion = 4  # output-channel multiplier
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
# conv1 1x1
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
# conv2 3x3
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
# conv3 1x1
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=768):
        # layers gives the block count per stage; block selects BasicBlock or Bottleneck
self.inplanes = 64
super(ResNet, self).__init__()
# 1.conv1
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# 2.conv2_x
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
# 3.conv3_x
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
# 4.conv4_x
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
# 5.conv5_x
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
        # initialize weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
        # the first residual block of each stage (possibly with downsampling) goes into layers
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
        # append the remaining residual blocks of the stage, completing its construction
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten the output to a vector
x = self.fc(x)
return x
# ResNet builders
# resnet18
def resnet18(out_fc):
model = ResNet(BasicBlock, [2, 2, 2, 2],out_fc)
return model
# resnet34
def resnet34(out_fc):
model = ResNet(BasicBlock, [3, 4, 6, 3],out_fc)
return model
# resnet50
def resnet50(out_fc):
model = ResNet(Bottleneck, [3, 4, 6, 3],out_fc)
return model
# resnet101
def resnet101(out_fc):
model = ResNet(Bottleneck, [3, 4, 23, 3],out_fc)
return model
# resnet152
def resnet152(out_fc):
model = ResNet(Bottleneck, [3, 8, 36, 3],out_fc)
return model
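# Hedged usage note (editor's sketch, not part of the original file): unlike
# torchvision's resnet18(num_classes=...), these builders take the output feature
# dimension directly so it can match the BERT branch, e.g.
#   backbone = resnet18(256)            # final fc maps 512 -> 256
#   x = torch.rand(1, 3, 224, 224)      # requires `import torch`
#   backbone(x).shape                   # torch.Size([1, 256])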
# coding: UTF-8
from sklearn import metrics
import time
import numpy as np
from tensorboardX import SummaryWriter
from utils import My_Dataset,get_time_dif
from models import *
from Config import Config
from torch.utils.data import DataLoader
def train(config, model, train_iter, dev_iter, test_iter,writer):
start_time = time.time()
# writer.add_graph(model,input_to_model=((torch.rand(4,256,256,3).to(config.device),
# torch.LongTensor(4,128).to(config.device),
# torch.LongTensor(4,128).to(config.device)),))
model.train()
# print([n for n, p in model.named_parameters() if 'bert' in n])
# print([n for n, p in model.named_parameters() if 'resnet' in n])
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if 'bert' in n], 'lr': config.bert_learning_rate},  # BERT layers
        {'params': [p for n, p in model.named_parameters() if 'resnet' in n], 'lr': config.resnet_learning_rate},  # ResNet layers
        {'params': [p for n, p in model.named_parameters() if 'resnet' not in n and 'bert' not in n]}]  # remaining layers
    optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=config.other_learning_rate)  # optimizer with per-group learning rates
    # optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1)  # halve the learning rate every 2 epochs
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last validation-loss improvement
    flag = False  # whether training has gone a long time without improvement
    for epoch in range(config.num_epochs):
        loss_list = []  # per-batch losses for this epoch
        acc_list = []
print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
for i, (trains, labels) in enumerate(train_iter):
fea,outputs = model(trains)
optimizer.zero_grad()
#print(labels)
if config.usesloss:
bloss = F.cross_entropy(outputs, labels)
sloss=SupConLoss()
sloss=sloss(fea,labels=labels)
loss=(bloss+sloss)/2
else:
loss = F.cross_entropy(outputs, labels)
#print(bloss, sloss, loss)
loss.backward()
optimizer.step()
true = labels.data.cpu()
predic = torch.max(outputs.data, 1)[1].cpu()
train_acc = metrics.accuracy_score(true, predic)
writer.add_scalar('train/loss_iter', loss.item(),total_batch)
writer.add_scalar('train/acc_iter',train_acc,total_batch)
msg1 = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}'
if total_batch%20==0:
print(msg1.format(total_batch, loss.item(), train_acc))
loss_list.append(loss.item())
acc_list.append(train_acc)
total_batch += 1
if total_batch - last_improve > config.require_improvement:
                # the validation loss has not improved for over require_improvement batches; stop training
print("No optimization for a long time, auto-stopping...")
flag = True
break
if flag:
break
dev_acc, dev_loss = evaluate(config, model, dev_iter)#model.eval()
if dev_loss < dev_best_loss:
dev_best_loss = dev_loss
torch.save(model.state_dict(), config.save_path)
improve = '*'
last_improve = total_batch
else:
improve = ''
time_dif = get_time_dif(start_time)
epoch_loss=np.mean(loss_list)
epoch_acc=np.mean(acc_list)
msg2 = 'EPOCH: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
print(msg2.format(epoch+1,epoch_loss, epoch_acc, dev_loss, dev_acc, time_dif, improve))
writer.add_scalar('train/loss_epoch',epoch_loss, epoch)
writer.add_scalar('train/acc_epoch', epoch_acc, epoch)
writer.add_scalar('val/loss_epoch', dev_loss, epoch)
writer.add_scalar('val/acc_epoch', dev_acc, epoch)
model.train()
scheduler.step()
print('epoch: ', epoch, 'lr: ', scheduler.get_last_lr())
test(config, model, test_iter)
def test(config, model, test_iter):
    # evaluate the best checkpoint on the test set
model.load_state_dict(torch.load(config.save_path))
model.eval()
start_time = time.time()
test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
print(msg.format(test_loss, test_acc))
print("Precision, Recall and F1-Score...") #精确率和召回率以及调和平均数
print(test_report)
print("Confusion Matrix...")
print(test_confusion)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)
def evaluate(config, model, data_iter, test=False):
model.eval()
loss_total = 0
predict_all = np.array([], dtype=int)
labels_all = np.array([], dtype=int)
with torch.no_grad():
for texts, labels in data_iter:
#print(texts)
fea,outputs = model(texts)
if config.usesloss:
bloss = F.cross_entropy(outputs, labels)
sloss=SupConLoss()
sloss=sloss(fea,labels=labels)
loss=(bloss+sloss)/2
else:
loss = F.cross_entropy(outputs, labels)
loss_total += loss
labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()  # predicted class indices
# print(outputs)
# print(predic)
# print(labels)
# print('*************************')
labels_all = np.append(labels_all, labels)
predict_all = np.append(predict_all, predic)
acc = metrics.accuracy_score(labels_all, predict_all)
if test:
report = metrics.classification_report(labels_all, predict_all, digits=4)
confusion = metrics.confusion_matrix(labels_all, predict_all)
return acc, loss_total / len(data_iter), report, confusion
return acc, loss_total / len(data_iter)
if __name__ == '__main__':
config = Config()
writer = SummaryWriter(log_dir=config.log_dir)
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible
print("Loading data...")
train_data=My_Dataset('./data/train.csv',config,1)
dev_data = My_Dataset('./data/val.csv',config,1)
test_data = My_Dataset('./data/test.csv',config,1)
    train_iter = DataLoader(train_data, batch_size=config.batch_size, shuffle=True)  # training iterator
    dev_iter = DataLoader(dev_data, batch_size=config.batch_size, shuffle=True)  # validation iterator
    test_iter = DataLoader(test_data, batch_size=config.batch_size, shuffle=True)  # test iterator
    # training
    mynet = Mynet(config)
    # move the model to the GPU/CPU device
    mynet = mynet.to(config.device)
    print(mynet.parameters)
    # after training you can comment out train() and run only test() to evaluate the model
    # test(config, mynet, test_iter)
    train(config, mynet, train_iter, dev_iter, test_iter, writer)
#tensorboard --logdir=log/bert-base-chinese_resnet18 --port=6006
import torch
from transformers import BertTokenizer
from PIL import Image
from torchvision import transforms
from models import Mynet  # the multimodal model defined in models.py
from Config import Config  # the configuration defined in Config.py
# device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the configuration and point at the trained checkpoint
config = Config()
image_model_path = 'model/S_minirbt-h256_resnet18.ckpt'  # trained checkpoint path
# initialize the model
model = Mynet(config)
model.load_state_dict(torch.load(image_model_path, map_location=device))
model.eval()
model.to(device)
# load the tokenizer
text_tokenizer = BertTokenizer.from_pretrained(config.bert_model_path)  # replace with your own BERT path if needed
# image preprocessing (kept consistent with the training pipeline in utils.py:
# resize to 224x224 and scale to [0, 1] without extra normalization)
def preprocess_image(image_path):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert('RGB')
    image = preprocess(image).unsqueeze(0).to(device)
    return image

# text preprocessing (same maximum length and padding as training)
def preprocess_text(text):
    inputs = text_tokenizer(text, return_tensors='pt', padding='max_length',
                            truncation=True, max_length=config.pad_size)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    return inputs
# load an image and a piece of text
image_path = 'image.png'  # replace with your image path
text = "长相是真美,演技也进步了,5年15部剧却被全网嘲的她在《墨雨云间》翻身"  # replace with your news text
image_input = preprocess_image(image_path)
text_input = preprocess_text(text)
# run inference
with torch.no_grad():
    # extract image features and class scores
    img_features, logits = model((image_input, text_input['input_ids'], text_input['attention_mask']))
    # Mynet's forward already applies softmax, so logits here are class probabilities
    probs = logits.cpu().numpy()
# print the predicted probabilities
print("predicted probabilities:", probs)
# class labels: index 0 = real news (真实新闻), index 1 = fake news (虚假新闻), matching the training labels
labels = ['真实新闻', '虚假新闻']
predicted_label = labels[torch.argmax(logits).item()]
print("predicted label:", predicted_label)
import os
import torch
from transformers import BertTokenizer
from PIL import Image
from torchvision import transforms
from models import Mynet  # the multimodal model defined in models.py
from Config import Config  # the configuration defined in Config.py
# device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the configuration and point at the trained checkpoint
config = Config()
image_model_path = 'model/S_minirbt-h256_resnet18.ckpt'  # trained checkpoint path
# initialize the model
model = Mynet(config)
model.load_state_dict(torch.load(image_model_path, map_location=device))
model.eval()
model.to(device)
# load the tokenizer
text_tokenizer = BertTokenizer.from_pretrained(config.bert_model_path)  # replace with your own BERT path if needed
# image preprocessing (kept consistent with the training pipeline in utils.py:
# resize to 224x224 and scale to [0, 1] without extra normalization)
def preprocess_image(image_path):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert('RGB')
    image = preprocess(image).unsqueeze(0).to(device)
    return image

# text preprocessing (same maximum length and padding as training)
def preprocess_text(text):
    inputs = text_tokenizer(text, return_tensors='pt', padding='max_length',
                            truncation=True, max_length=config.pad_size)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    return inputs
# batch prediction
def batch_predict(image_paths, texts):
    results = []
    for image_path, text in zip(image_paths, texts):
        image_input = preprocess_image(image_path)
        text_input = preprocess_text(text)
        # run inference
        with torch.no_grad():
            # extract image features and class scores
            img_features, logits = model((image_input, text_input['input_ids'], text_input['attention_mask']))
            # Mynet's forward already applies softmax, so logits here are class probabilities
            probs = logits.cpu().numpy()
        # class labels: index 0 = real news (真实新闻), index 1 = fake news (虚假新闻), matching the training labels
        labels = ['真实新闻', '虚假新闻']
        predicted_label = labels[torch.argmax(logits).item()]
        results.append({
            "image_path": image_path,
            "text": text,
            "predicted_label": predicted_label,
            "probabilities": probs
        })
    return results
# read the input files (one image path / one text per line)
image_file_path = 'images.txt'
text_file_path = 'texts.txt'
with open(image_file_path, 'r', encoding='utf-8') as f:
    image_paths = [line.strip() for line in f.readlines()]
with open(text_file_path, 'r', encoding='utf-8') as f:
    texts = [line.strip() for line in f.readlines()]
# the number of image paths must match the number of texts
assert len(image_paths) == len(texts), "number of image files and texts differ"
# batch prediction
results = batch_predict(image_paths, texts)
# print the results
for result in results:
    print(f"image path: {result['image_path']}")
    print(f"text: {result['text']}")
    print(f"predicted label: {result['predicted_label']}")
    print(f"predicted probabilities: {result['probabilities']}")
    print("---------")
import torch
import numpy as np
import time
from datetime import timedelta
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader
import cv2
from transformers import BertTokenizer
from Config import Config
# a. load the tokenizer from a vocabulary
# "bert-base-chinese"
# bert_model/chinese-bert-wwm-ext
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
class My_Dataset(Dataset):
    def __init__(self, path, config, iftrain):  # read the dataset
        self.config = config
        self.iftrain = iftrain
        df = pd.read_csv(path).sample(frac=self.config.frac)
        self.img_path = df['path'].to_list()
        self.text = df['text'].to_list()
        self.tokenizer = BertTokenizer.from_pretrained(self.config.bert_name)
        # in training mode, also load the labels
        if self.iftrain == 1:
            self.labels = df['label'].to_list()
    def __getitem__(self, idx):
        img = Image.open(self.img_path[idx])
        img = img.convert("RGB")
        img = np.array(img)
        img = cv2.resize(img, (224, 224))
        img = img / 255.
        img = np.transpose(img, (2, 0, 1))
        img = torch.tensor(img, dtype=torch.float32)
        text = self.text[idx]
        try:
            len(text)  # some texts are nan
        except Exception:
            text = ''
        text = self.tokenizer(text=text, add_special_tokens=True,
                              max_length=self.config.pad_size,  # maximum sentence length
                              truncation=True,
                              padding='max_length')  # pad to the maximum length
        # print(text)
        # the tokenizer output has three parts; token_type_ids marks sentence pairs,
        # e.g. (t1 "我 吃 饭", t2 "i eat food") -> [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]]
        input_id = torch.tensor(text['input_ids'], dtype=torch.long)
        attention_mask = torch.tensor(text['attention_mask'], dtype=torch.long)  # optional, used here
        if self.iftrain == 1:
            label = int(self.labels[idx])
            label = torch.tensor(label, dtype=torch.long)
            return (img.to(self.config.device), input_id.to(self.config.device), attention_mask.to(self.config.device)), label.to(self.config.device)
        else:
            return (img.to(self.config.device), input_id.to(self.config.device), attention_mask.to(self.config.device))

    def __len__(self):
        return len(self.img_path)  # total number of samples
def get_time_dif(start_time):
end_time = time.time()
time_dif = end_time - start_time
return timedelta(seconds=int(round(time_dif)))
if __name__=='__main__':
config=Config()
train_data=My_Dataset('./data/train.csv',config,1)
train_iter = DataLoader(train_data, batch_size=32)
n=0
for a,b in train_iter:
n=n+1
print(n,b.shape)
#print(y)
print('************')
Images can be generated simply with the `use` command.
Model link
Link: https://pan.baidu.com/s/1pCeuhqQ-rpdP2LV2X7eBGw
Extraction code: 55ms
Demo video
Link: https://pan.baidu.com/s/164NbCXTEqKUhghh2kAcT2w
Extraction code: d9na
* linguist-vendored
*.py linguist-vendored=false
**/__pycache__/
**/.cache/
**/*.pkl
**/.DS*
**/*.pt
*.mlpackage
**/*.ckpt
.vscode
**/.ipynb_checkpoints
generated.png
**/generated
**/pretrained
**/*.msgpack
*.egg-info/
.idea/
*.egg
dist
build
README
**/.cog
**/cog
Copyright 2022 Brett Kuprel
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# min(DALL·E)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
&nbsp;
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces%20Demo-blue)](https://huggingface.co/spaces/kuprel/min-dalle)
&nbsp;
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
&nbsp;
[![Discord](https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white)](https://discord.com/channels/823813159592001537/912729332311556136)
[YouTube Walk-through](https://youtu.be/x_8uHX5KngE) by The AI Epiphany
This is a fast, minimal port of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) (with mega weights). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch.
To generate a 3x3 grid of DALL·E Mega images it takes:
- 55 sec with a T4 in Colab
- 33 sec with a P100 in Colab
- 15 sec with an A10G on Hugging Face
Here's a more detailed breakdown of performance on an A100. Credit to [@technobird22](https://github.com/technobird22) and his [NeoGen](https://github.com/technobird22/NeoGen) discord bot for the graph.
<br />
<img src="https://github.com/kuprel/min-dalle/raw/main/performance.png" alt="min-dalle" width="450"/>
<br />
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
## Install
```bash
$ pip install min-dalle
```
## Usage
Load the model parameters once and reuse the model to generate multiple images.
```python
from min_dalle import MinDalle
model = MinDalle(
models_root='./pretrained',
dtype=torch.float32,
device='cuda',
is_mega=True,
is_reusable=True
)
```
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. If `is_seamless` is true, the image grid will be tiled in token space not pixel space.
```python
image = model.generate_image(
text='Nuclear explosion broccoli',
seed=-1,
grid_size=4,
is_seamless=False,
temperature=1,
top_k=256,
supercondition_factor=32,
is_verbose=False
)
display(image)
```
<img src="https://github.com/kuprel/min-dalle/raw/main/examples/nuclear_broccoli.jpg" alt="min-dalle" width="400"/>
Credit to [@hardmaru](https://twitter.com/hardmaru) for the [example](https://twitter.com/hardmaru/status/1544354119527596034)
### Saving Individual Images
The images can also be generated as a `FloatTensor` in case you want to process them manually.
```python
images = model.generate_images(
text='Nuclear explosion broccoli',
seed=-1,
grid_size=3,
is_seamless=False,
temperature=1,
top_k=256,
supercondition_factor=16,
is_verbose=False
)
```
To get an image into PIL format you will have to first move the images to the CPU and convert the tensor to a numpy array.
```python
images = images.to('cpu').numpy()
```
Then image $i$ can be converted to a `PIL.Image` and saved:
```python
image = Image.fromarray(images[i])
image.save('image_{}.png'.format(i))
```
### Progressive Outputs
If the model is being used interactively (e.g. in a notebook) `generate_image_stream` can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Set `progressive_outputs` to `True` to enable this. An example is implemented in the colab.
```python
image_stream = model.generate_image_stream(
text='Dali painting of WALL·E',
seed=-1,
grid_size=3,
progressive_outputs=True,
is_seamless=False,
temperature=1,
top_k=256,
supercondition_factor=16,
is_verbose=False
)
for image in image_stream:
display(image)
```
<img src="https://github.com/kuprel/min-dalle/raw/main/examples/dali_walle_animated.gif" alt="min-dalle" width="300"/>
### Command Line
Use `image_from_text.py` to generate images from the command line.
```bash
$ python image_from_text.py --text='artificial intelligence' --no-mega
```
<img src="https://github.com/kuprel/min-dalle/raw/main/examples/artificial_intelligence.jpg" alt="min-dalle" width="200"/>
import argparse
import os
from PIL import Image
from min_dalle import MinDalle
import torch
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
parser.add_argument('--no-mega', dest='mega', action='store_false')
parser.set_defaults(mega=False)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--text', type=str, default='Dali painting of WALL·E')
parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--grid-size', type=int, default=1)
parser.add_argument('--image-path', type=str, default='generated')
parser.add_argument('--models-root', type=str, default='pretrained')
parser.add_argument('--top_k', type=int, default=256)
def ascii_from_image(image: Image.Image, size: int = 128) -> str:
    # downscale (the 0.55 height factor compensates for tall terminal characters),
    # convert to grayscale, then map each pixel to a brightness-ranked character
    gray_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata()
    chars = list('.,;/IOX')
    chars = [chars[i * len(chars) // 256] for i in gray_pixels]
    chars = [chars[i * size: (i + 1) * size] for i in range(size // 2)]
    return '\n'.join(''.join(row) for row in chars)
def save_image(image: Image.Image, path: str):
if os.path.isdir(path):
path = os.path.join(path, 'generated.png')
elif not path.endswith('.png'):
path += '.png'
print("saving image to", path)
image.save(path)
return image
def generate_image(
is_mega: bool,
text: str,
seed: int,
grid_size: int,
top_k: int,
image_path: str,
models_root: str,
fp16: bool,
):
model = MinDalle(
is_mega=is_mega,
models_root=models_root,
is_reusable=False,
is_verbose=True,
dtype=torch.float16 if fp16 else torch.float32
)
image = model.generate_image(
text,
seed,
grid_size,
top_k=top_k,
is_verbose=True
)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__':
args = parser.parse_args()
print(args)
generate_image(
is_mega=args.mega,
text=args.text,
seed=args.seed,
grid_size=args.grid_size,
top_k=args.top_k,
image_path=args.image_path,
models_root=args.models_root,
fp16=args.fp16,
)
from .min_dalle import MinDalle
import os
from PIL import Image
import numpy
from torch import LongTensor, FloatTensor
import torch
import torch.backends.cudnn, torch.backends.cuda
import json
import requests
from typing import Iterator
from .text_tokenizer import TextTokenizer
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
IMAGE_TOKEN_COUNT = 256
class MinDalle:
def __init__(
self,
models_root: str = 'pretrained',
dtype: torch.dtype = torch.float32,
device: str = None,
is_mega: bool = True,
is_reusable: bool = True,
is_verbose=True
):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if is_verbose: print("using device", device)
self.device = device
self.is_mega = is_mega
self.is_reusable = is_reusable
self.dtype = dtype
self.is_verbose = is_verbose
self.text_token_count = 64
self.layer_count = 24 if is_mega else 12
self.attention_head_count = 32 if is_mega else 16
self.embed_count = 2048 if is_mega else 1024
self.glu_embed_count = 4096 if is_mega else 2730
self.text_vocab_count = 50272 if is_mega else 50264
self.image_vocab_count = 16415 if is_mega else 16384
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
dalle_path = os.path.join(models_root, model_name)
vqgan_path = os.path.join(models_root, 'vqgan')
if not os.path.exists(dalle_path): os.makedirs(dalle_path)
if not os.path.exists(vqgan_path): os.makedirs(vqgan_path)
self.vocab_path = os.path.join(dalle_path, 'vocab.json')
self.merges_path = os.path.join(dalle_path, 'merges.txt')
self.encoder_params_path = os.path.join(dalle_path, 'encoder.pt')
self.decoder_params_path = os.path.join(dalle_path, 'decoder.pt')
self.detoker_params_path = os.path.join(vqgan_path, 'detoker.pt')
self.init_tokenizer()
if is_reusable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()
def init_tokenizer(self):
if not (os.path.exists(self.vocab_path) and os.path.exists(self.merges_path)):
self.download_tokenizer()
if self.is_verbose: print("initializing TextTokenizer")
with open(self.vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(self.merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.tokenizer = TextTokenizer(vocab, merges)
def download_tokenizer(self):
if self.is_verbose: print("downloading tokenizer params")
suffix = '' if self.is_mega else '_mini'
_ = requests.get(MIN_DALLE_REPO + 'config.json') # trigger HF download
vocab = requests.get(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix))
merges = requests.get(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix))
with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
with open(self.merges_path, 'wb') as f: f.write(merges.content)
def init_encoder(self):
if not os.path.exists(self.encoder_params_path):
self.download_encoder()
if self.is_verbose: print("initializing DalleBartEncoder")
self.encoder = DalleBartEncoder(
attention_head_count=self.attention_head_count,
embed_count=self.embed_count,
glu_embed_count=self.glu_embed_count,
text_token_count=self.text_token_count,
text_vocab_count=self.text_vocab_count,
layer_count=self.layer_count,
device=self.device
).to(self.dtype).eval()
params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False)
del params
self.encoder = self.encoder.to(device=self.device)
def download_encoder(self):
if self.is_verbose: print("downloading encoder params")
suffix = '' if self.is_mega else '_mini'
params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
def init_decoder(self):
if not os.path.exists(self.decoder_params_path):
self.download_decoder()
if self.is_verbose: print("initializing DalleBartDecoder")
self.decoder = DalleBartDecoder(
image_vocab_count=self.image_vocab_count,
attention_head_count=self.attention_head_count,
embed_count=self.embed_count,
glu_embed_count=self.glu_embed_count,
layer_count=self.layer_count,
device=self.device
).to(self.dtype).eval()
params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False)
del params
self.decoder = self.decoder.to(device=self.device)
def download_decoder(self):
if self.is_verbose: print("downloading decoder params")
suffix = '' if self.is_mega else '_mini'
params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
def init_detokenizer(self):
if not os.path.exists(self.detoker_params_path):
self.download_detokenizer()
if self.is_verbose: print("initializing VQGanDetokenizer")
self.detokenizer = VQGanDetokenizer().eval()
params = torch.load(self.detoker_params_path)
self.detokenizer.load_state_dict(params)
del params
self.detokenizer = self.detokenizer.to(device=self.device)
def download_detokenizer(self):
if self.is_verbose: print("downloading detokenizer params")
params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
def image_grid_from_tokens(
self,
image_tokens: LongTensor,
is_seamless: bool,
is_verbose: bool = False
) -> FloatTensor:
if not self.is_reusable: del self.decoder
torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer()
if is_verbose: print("detokenizing image")
images = self.detokenizer.forward(is_seamless, image_tokens)
if not self.is_reusable: del self.detokenizer
return images
def generate_raw_image_stream(
self,
text: str,
seed: int,
grid_size: int,
progressive_outputs: bool = False,
is_seamless: bool = False,
temperature: float = 1,
top_k: int = 256,
supercondition_factor: int = 16,
is_verbose: bool = False
) -> Iterator[FloatTensor]:
image_count = grid_size ** 2
if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
if len(tokens) > self.text_token_count:
tokens = tokens[:self.text_token_count]
if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
text_tokens = torch.tensor(
text_tokens,
dtype=torch.long,
device=self.device
)
if not self.is_reusable: self.init_encoder()
if is_verbose: print("encoding text tokens")
with torch.cuda.amp.autocast(dtype=self.dtype):
encoder_state = self.encoder.forward(text_tokens)
if not self.is_reusable: del self.encoder
torch.cuda.empty_cache()
if not self.is_reusable: self.init_decoder()
with torch.cuda.amp.autocast(dtype=self.dtype):
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices]
attention_mask = text_tokens.not_equal(1)[:, None, None, :]
attention_state = torch.zeros(
size=(
self.layer_count,
image_count * 4,
IMAGE_TOKEN_COUNT,
self.embed_count
),
device=self.device
)
image_tokens = torch.full(
(image_count, IMAGE_TOKEN_COUNT + 1),
2 ** 14 - 1,
dtype=torch.long,
device=self.device
)
if seed > 0: torch.manual_seed(seed)
token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
settings = torch.tensor(
[temperature, top_k, supercondition_factor],
dtype=torch.float32,
device=self.device
)
for i in range(IMAGE_TOKEN_COUNT):
torch.cuda.empty_cache()
with torch.cuda.amp.autocast(dtype=self.dtype):
image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
settings=settings,
attention_mask=attention_mask,
encoder_state=encoder_state,
attention_state=attention_state,
# prev_tokens=image_tokens[:, :i+1],
# token_index=token_indices[:i+1]
prev_tokens=image_tokens[:, [i]],
token_index=token_indices[[i]]
)
with torch.cuda.amp.autocast(dtype=torch.float32):
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
yield self.image_grid_from_tokens(
image_tokens=image_tokens[:, 1:],
is_seamless=is_seamless,
is_verbose=is_verbose
)
def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
image_stream = self.generate_raw_image_stream(*args, **kwargs)
for image in image_stream:
image = image.to(torch.uint8).to('cpu').numpy()
yield Image.fromarray(image)
def generate_images_stream(self, *args, **kwargs) -> Iterator[FloatTensor]:
image_stream = self.generate_raw_image_stream(*args, **kwargs)
for image in image_stream:
grid_size = kwargs['grid_size']
image = image.view([grid_size * 256, grid_size, 256, 3])
image = image.transpose(1, 0)
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
yield image
def generate_image(self, *args, **kwargs) -> Image.Image:
image_stream = self.generate_image_stream(
*args, **kwargs,
progressive_outputs=False
)
return next(image_stream)
def generate_images(self, *args, **kwargs) -> Image.Image:
images_stream = self.generate_images_stream(
*args, **kwargs,
progressive_outputs=False
)
return next(images_stream)
from .dalle_bart_encoder import DalleBartEncoder
from .dalle_bart_decoder import DalleBartDecoder
from .vqgan_detokenizer import VQGanDetokenizer
from typing import Tuple, List
import torch
from torch import nn, LongTensor, FloatTensor, BoolTensor
from .dalle_bart_encoder import GLU, AttentionBase
IMAGE_TOKEN_COUNT = 256
class DecoderCrossAttention(AttentionBase):
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(decoder_state)
return super().forward(keys, values, queries, attention_mask)
class DecoderSelfAttention(AttentionBase):
def __init__(self, head_count: int, embed_count: int):
super().__init__(head_count, embed_count)
def forward(
self,
decoder_state: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
keys = self.k_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state)
queries = self.q_proj.forward(decoder_state)
token_count = token_index.shape[1]
if token_count == 1:
batch_count = decoder_state.shape[0]
attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
attention_state[:, token_index[0]] = attn_state_new
keys = attention_state[:batch_count]
values = attention_state[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask)
return decoder_state, attention_state
class DecoderLayer(nn.Module):
def __init__(
self,
head_count: int,
embed_count: int,
glu_embed_count: int,
device: str
):
super().__init__()
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
self.self_attn = DecoderSelfAttention(head_count, embed_count)
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLU(embed_count, glu_embed_count)
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
# Self Attention
token_count = token_index.shape[1]
if token_count == 1:
self_attn_mask = self.token_indices <= token_index
self_attn_mask = self_attn_mask[:, None, None, :]
else:
self_attn_mask = (
self.token_indices[None, None, :token_count] <=
token_index[:, :, None]
)
self_attn_mask = self_attn_mask[:, None, :, :]
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
decoder_state, attention_state = self.self_attn.forward(
decoder_state=decoder_state,
attention_state=attention_state,
attention_mask=self_attn_mask,
token_index=token_index
)
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
decoder_state = residual + decoder_state
# Cross Attention
residual = decoder_state
decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
decoder_state = self.encoder_attn.forward(
decoder_state=decoder_state,
encoder_state=encoder_state,
attention_mask=attention_mask
)
decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
decoder_state = residual + decoder_state
# Feed forward
residual = decoder_state
decoder_state = self.glu.forward(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, attention_state
class DalleBartDecoder(nn.Module):
def __init__(
self,
image_vocab_count: int,
embed_count: int,
attention_head_count: int,
glu_embed_count: int,
layer_count: int,
device: str
):
super().__init__()
self.layer_count = layer_count
self.embed_count = embed_count
self.image_vocab_count = image_vocab_count
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
self.layers: List[DecoderLayer] = nn.ModuleList([
DecoderLayer(
head_count=attention_head_count,
embed_count=embed_count,
glu_embed_count=glu_embed_count,
device=device
)
for _ in range(layer_count)
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
def forward(
self,
attention_mask: BoolTensor,
encoder_state: FloatTensor,
attention_state: FloatTensor,
prev_tokens: LongTensor,
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
image_count = encoder_state.shape[0] // 2
token_index = token_index.unsqueeze(0).repeat(image_count * 2, 1)
prev_tokens = prev_tokens.repeat(2, 1)
decoder_state = self.embed_tokens.forward(prev_tokens)
decoder_state += self.embed_positions.forward(token_index)
decoder_state = self.layernorm_embedding.forward(decoder_state)
for i in range(self.layer_count):
decoder_state, attention_state[i] = self.layers[i].forward(
decoder_state,
encoder_state,
attention_state[i],
attention_mask,
token_index
)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
return logits, attention_state
def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
logits, attention_state = self.forward(**kwargs)
image_count = logits.shape[0] // 2
temperature = settings[[0]]
top_k = settings[[1]].to(torch.long)
supercondition_factor = settings[[2]]
logits = logits[:, -1, : 2 ** 14]
logits: FloatTensor = (
logits[:image_count] * (1 - supercondition_factor) +
logits[image_count:] * supercondition_factor
)
logits_sorted, _ = logits.sort(descending=True)
is_kept = logits >= logits_sorted[:, top_k - 1]
logits -= logits_sorted[:, [0]]
logits /= temperature
logits.exp_()
logits *= is_kept.to(torch.float32)
image_tokens = torch.multinomial(logits, 1)[:, 0]
return image_tokens, attention_state
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
class GLU(nn.Module):
def __init__(self, count_in_out: int, count_middle: int):
super().__init__()
self.gelu = nn.GELU()
self.ln0 = nn.LayerNorm(count_in_out)
self.ln1 = nn.LayerNorm(count_middle)
self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)
self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)
self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)
def forward(self, z: FloatTensor) -> FloatTensor:
z = self.ln0.forward(z)
w = self.fc0.forward(z)
w = self.gelu.forward(w)
v = self.fc1.forward(z)
z = self.ln1.forward(w * v)
z = self.fc2.forward(z)
return z
class AttentionBase(nn.Module):
def __init__(self, head_count: int, embed_count: int):
super().__init__()
self.head_count = head_count
self.embed_count = embed_count
self.k_proj = nn.Linear(embed_count, embed_count, bias=False)
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
def forward(
self,
keys: FloatTensor,
values: FloatTensor,
queries: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
attention_weights: FloatTensor = torch.einsum(
'bqhc,bkhc->bhqk',
queries,
keys
)
attention_weights += attention_bias
attention_weights = torch.softmax(attention_weights, -1)
attention_output: FloatTensor = torch.einsum(
"bhqk,bkhc->bqhc",
attention_weights,
values
)
shape = attention_output.shape[:2] + (self.embed_count,)
attention_output = attention_output.reshape(shape)
attention_output = self.out_proj.forward(attention_output)
return attention_output
class EncoderSelfAttention(AttentionBase):
def forward(
self,
encoder_state: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(encoder_state)
return super().forward(keys, values, queries, attention_mask)
class EncoderLayer(nn.Module):
def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
super().__init__()
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
self.self_attn = EncoderSelfAttention(head_count, embed_count)
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLU(embed_count, glu_embed_count)
def forward(
self,
encoder_state: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
residual = encoder_state
encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state)
encoder_state = self.self_attn.forward(encoder_state, attention_mask)
encoder_state = self.self_attn_layer_norm.forward(encoder_state)
encoder_state = residual + encoder_state
residual = encoder_state
encoder_state = self.glu.forward(encoder_state)
encoder_state = residual + encoder_state
return encoder_state
class DalleBartEncoder(nn.Module):
def __init__(
self,
layer_count: int,
embed_count: int,
attention_head_count: int,
text_vocab_count: int,
text_token_count: int,
glu_embed_count: int,
device: str
):
super().__init__()
self.text_vocab_count = text_vocab_count
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: List[EncoderLayer] = nn.ModuleList([
EncoderLayer(
embed_count = embed_count,
head_count = attention_head_count,
glu_embed_count = glu_embed_count
)
for _ in range(layer_count)
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
token_indices = torch.arange(text_token_count, device=device)
self.pose_tokens = torch.stack([token_indices] * 2)
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)[:, None, None, :]
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(self.pose_tokens)
)
encoder_state = self.layernorm_embedding.forward(encoder_state)
for layer in self.layers:
encoder_state = layer.forward(encoder_state, attention_mask)
encoder_state = self.final_ln.forward(encoder_state)
return encoder_state
import torch
from torch import nn
from torch import FloatTensor, LongTensor
from math import sqrt
class ResnetBlock(nn.Module):
def __init__(self, log2_count_in: int, log2_count_out: int):
super().__init__()
m, n = 2 ** log2_count_in, 2 ** log2_count_out
self.is_middle = m == n
self.norm1 = nn.GroupNorm(2 ** 5, m)
self.conv1 = nn.Conv2d(m, n, 3, padding=1)
self.norm2 = nn.GroupNorm(2 ** 5, n)
self.conv2 = nn.Conv2d(n, n, 3, padding=1)
if not self.is_middle:
self.nin_shortcut = nn.Conv2d(m, n, 1)
def forward(self, x: FloatTensor) -> FloatTensor:
h = x
h = self.norm1.forward(h)
h *= torch.sigmoid(h)
h = self.conv1.forward(h)
h = self.norm2.forward(h)
h *= torch.sigmoid(h)
h = self.conv2(h)
if not self.is_middle:
x = self.nin_shortcut.forward(x)
return x + h
class AttentionBlock(nn.Module):
def __init__(self):
super().__init__()
n = 2 ** 9
self.norm = nn.GroupNorm(2 ** 5, n)
self.q = nn.Conv2d(n, n, 1)
self.k = nn.Conv2d(n, n, 1)
self.v = nn.Conv2d(n, n, 1)
self.proj_out = nn.Conv2d(n, n, 1)
def forward(self, x: FloatTensor) -> FloatTensor:
n, m = 2 ** 9, x.shape[0]
h = x
h = self.norm(h)
k = self.k.forward(h)
v = self.v.forward(h)
q = self.q.forward(h)
k = k.reshape(m, n, -1)
v = v.reshape(m, n, -1)
q = q.reshape(m, n, -1)
q = q.permute(0, 2, 1)
w = torch.bmm(q, k)
w /= n ** 0.5
w = torch.softmax(w, dim=2)
w = w.permute(0, 2, 1)
h = torch.bmm(v, w)
token_count = int(sqrt(h.shape[-1]))
h = h.reshape(m, n, token_count, token_count)
h = self.proj_out.forward(h)
return x + h
class MiddleLayer(nn.Module):
def __init__(self):
super().__init__()
self.block_1 = ResnetBlock(9, 9)
self.attn_1 = AttentionBlock()
self.block_2 = ResnetBlock(9, 9)
def forward(self, h: FloatTensor) -> FloatTensor:
h = self.block_1.forward(h)
h = self.attn_1.forward(h)
h = self.block_2.forward(h)
return h
class Upsample(nn.Module):
def __init__(self, log2_count):
super().__init__()
n = 2 ** log2_count
self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
self.conv = nn.Conv2d(n, n, 3, padding=1)
def forward(self, x: FloatTensor) -> FloatTensor:
x = self.upsample.forward(x.to(torch.float32))
x = self.conv.forward(x)
return x
class UpsampleBlock(nn.Module):
def __init__(
self,
log2_count_in: int,
log2_count_out: int,
has_attention: bool,
has_upsample: bool
):
super().__init__()
self.has_attention = has_attention
self.has_upsample = has_upsample
self.block = nn.ModuleList([
ResnetBlock(log2_count_in, log2_count_out),
ResnetBlock(log2_count_out, log2_count_out),
ResnetBlock(log2_count_out, log2_count_out)
])
if has_attention:
self.attn = nn.ModuleList([
AttentionBlock(),
AttentionBlock(),
AttentionBlock()
])
if has_upsample:
self.upsample = Upsample(log2_count_out)
def forward(self, h: FloatTensor) -> FloatTensor:
for j in range(3):
h = self.block[j].forward(h)
if self.has_attention:
h = self.attn[j].forward(h)
if self.has_upsample:
h = self.upsample.forward(h)
return h
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
self.mid = MiddleLayer()
self.up = nn.ModuleList([
UpsampleBlock(7, 7, False, False),
UpsampleBlock(8, 7, False, True),
UpsampleBlock(8, 8, False, True),
UpsampleBlock(9, 8, False, True),
UpsampleBlock(9, 9, True, True)
])
self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7)
self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1)
def forward(self, z: FloatTensor) -> FloatTensor:
z = self.conv_in.forward(z)
z = self.mid.forward(z)
for i in reversed(range(5)):
z = self.up[i].forward(z)
z = self.norm_out.forward(z)
z *= torch.sigmoid(z)
z = self.conv_out.forward(z)
return z
class VQGanDetokenizer(nn.Module):
def __init__(self):
super().__init__()
vocab_count, embed_count = 2 ** 14, 2 ** 8
self.vocab_count = vocab_count
self.embedding = nn.Embedding(vocab_count, embed_count)
self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1)
self.decoder = Decoder()
def forward(self, is_seamless: bool, z: LongTensor) -> FloatTensor:
grid_size = int(sqrt(z.shape[0]))
token_count = grid_size * 2 ** 4
if is_seamless:
z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4])
z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
z = z.flatten().unsqueeze(1)
z = self.embedding.forward(z)
z = z.view((1, token_count, token_count, 2 ** 8))
else:
z = self.embedding.forward(z)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous()
z = self.post_quant_conv.forward(z)
z = self.decoder.forward(z)
z = z.permute(0, 2, 3, 1)
z = z.clip(0.0, 1.0) * 255
if is_seamless:
z = z[0]
else:
z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3])
z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
return z
from math import inf
from typing import List, Tuple
from emoji import demojize
class TextTokenizer:
def __init__(self, vocab: dict, merges: List[str]):
self.token_from_subword = vocab
pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>']
text = demojize(text, delimiters=['', ''])
text = text.lower().encode("ascii", errors="ignore").decode()
tokens = [
self.token_from_subword.get(subword, unk_token)
for word in text.split(" ") if len(word) > 0
for subword in self.get_byte_pair_encoding(word, is_verbose)
]
return [cls_token] + tokens + [sep_token]
def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
def get_pair_rank(pair: Tuple[str, str]) -> int:
return self.rank_from_pair.get(pair, inf)
subwords = [chr(ord(" ") + 256)] + list(word)
while len(subwords) > 1:
pairs = list(zip(subwords[:-1], subwords[1:]))
pair_to_merge = min(pairs, key=get_pair_rank)
if pair_to_merge not in self.rank_from_pair: break
i = pairs.index(pair_to_merge)
subwords = (
(subwords[:i] if i > 0 else []) +
[subwords[i] + subwords[i + 1]] +
(subwords[i + 2:] if i + 2 < len(subwords) else [])
)
if is_verbose: print(subwords)
return subwords
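# Toy walk-through of the merge loop above (hypothetical vocab/merges, not
# the real model assets; chr(288) is the 'Ġ' word-start marker):
# vocab = {'<s>': 0, '</s>': 1, '<unk>': 2, 'Ġab': 3}
# merges = ['Ġ a', 'Ġa b']
# TextTokenizer(vocab, merges).tokenize('ab')  # -> [0, 3, 1]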
build:
cuda: "11.5.1"
gpu: true
python_version: "3.10"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "min-dalle==0.4.5"
- "emoji==1.7.0"
run:
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
predict: "predictor.py:ReplicatePredictor"
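# A local smoke test with this config would look something like the following
# (standard cog CLI; the prompt value is just an example):
# cog predict -i text="Dali painting of WALL·E" -i grid_size=3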
from min_dalle import MinDalle
import tempfile
import string
import torch, torch.backends.cudnn, torch.backends.cuda
from typing import Iterator
from emoji import demojize
from cog import BasePredictor, Path, Input
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
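# The flags above trade strict determinism for speed: cudnn.benchmark autotunes
# convolution kernels for repeated input shapes, and the TF32 / reduced-precision
# switches allow faster, slightly less precise matmul reductions on recent GPUs.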
def filename_from_text(text: str) -> str:
text = demojize(text, delimiters=['', ''])
text = text.lower().encode("ascii", errors="ignore").decode()
allowed_chars = string.ascii_lowercase + ' '
text = ''.join(i for i in text.lower() if i in allowed_chars)
text = text[:64]
text = '-'.join(text.strip().split())
if len(text) == 0: text = 'blank'
return text
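# e.g. filename_from_text('Dali painting of WALL·E') -> 'dali-painting-of-walle'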
class ReplicatePredictor(BasePredictor):
def setup(self):
self.model = MinDalle(
is_mega=True,
is_reusable=True,
dtype=torch.float32,
device='cuda'
)
def predict(
self,
text: str = Input(default='Dali painting of WALL·E'),
save_as_png: bool = Input(default=False),
progressive_outputs: bool = Input(default=True),
seamless: bool = Input(default=False),
grid_size: int = Input(ge=1, le=9, default=5),
temperature: float = Input(
ge=0.01,
le=16,
default=4
),
top_k: int = Input(
choices=[2 ** i for i in range(15)],
default=64,
description='Advanced Setting, see Readme below if interested.'
),
supercondition_factor: int = Input(
choices=[2 ** i for i in range(2, 7)],
default=16,
description='Advanced Setting, see Readme below if interested.'
)
) -> Iterator[Path]:
image_stream = self.model.generate_image_stream(
text = text,
seed = -1,
grid_size = grid_size,
progressive_outputs = progressive_outputs,
is_seamless = seamless,
temperature = temperature,
supercondition_factor = float(supercondition_factor),
top_k = top_k,
is_verbose = True
)
i = 0
path = Path(tempfile.mkdtemp())
for image in image_stream:
i += 1
is_final = i == 8 if progressive_outputs else True
ext = 'png' if is_final and save_as_png else 'jpg'
filename = filename_from_text(text)
filename += '' if is_final else '-iter-{}'.format(i)
image_path = path / '{}.{}'.format(filename, ext)
image.save(str(image_path))
yield image_path
min-dalle
numpy==1.23.0
pillow==9.2.0
requests==2.28.1
import setuptools
# from pathlib import Path
setuptools.setup(
name='min-dalle',
description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.4.11',
author='Brett Kuprel',
author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle',
packages=[
'min_dalle',
'min_dalle.models'
],
license='MIT',
install_requires=[
'torch>=1.11',
'typing_extensions>=4.1',
'numpy>=1.21',
'pillow>=7.1',
'requests>=2.23',
'emoji'
],
keywords = [
'artificial intelligence',
'deep learning',
'text-to-image',
'pytorch'
]
)
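# For local development, the standard editable install applies: pip install -e .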
from min_dalle import MinDalle
import os
import sys
import PIL
import PIL.Image
import PIL.ImageTk
import tkinter
from tkinter import ttk
def regen_root():
global root
global blank_image
global padding_image
root = tkinter.Tk()
root.wm_resizable(False, False)
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 2, 256 * 2), mode="RGB"))
padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
regen_root()
is_mega = None
def set_mega_true_and_destroy():
global is_mega
is_mega = True
root.destroy()
def set_mega_false_and_destroy():
global is_mega
is_mega = False
root.destroy()
frm = ttk.Frame(root, padding=16)
frm.grid()
ttk.Button(frm, text="Mega", command=set_mega_true_and_destroy).grid(column=0, row=0)
ttk.Label(frm, image=padding_image).grid(column=1, row=0)
ttk.Button(frm, text="Mini", command=set_mega_false_and_destroy).grid(column=2, row=0)
root.mainloop()
if is_mega is None:
print("no option selected")
sys.exit(0)
print("is_mega", is_mega)
model = MinDalle(
models_root="./pretrained",
is_mega=is_mega,
is_reusable=True,
is_verbose=True
)
regen_root()
label_image_content = blank_image
sv_prompt = tkinter.StringVar(value="artificial intelligence")
sv_temperature = tkinter.StringVar(value="1")
sv_topk = tkinter.StringVar(value="128")
sv_supercond = tkinter.StringVar(value="16")
bv_seamless = tkinter.BooleanVar(value=False)
def generate():
# validate the input fields; mark bad numeric input instead of crashing
try:
temperature = float(sv_temperature.get())
except ValueError:
sv_temperature.set("ERROR")
return
try:
topk = int(sv_topk.get())
except ValueError:
sv_topk.set("ERROR")
return
try:
supercond = int(sv_supercond.get())
except ValueError:
sv_supercond.set("ERROR")
return
is_seamless = bool(bv_seamless.get())
# all inputs parsed; stream progressively refined grids into the preview
global label_image_content
image_stream = model.generate_image_stream(
sv_prompt.get(),
grid_size=2,
seed=-1,
progressive_outputs=True,
is_seamless=is_seamless,
temperature=temperature,
top_k=topk,
supercondition_factor=supercond,
is_verbose=True
)
for image in image_stream:
global final_image
final_image = image
label_image_content = PIL.ImageTk.PhotoImage(image)
label_image.configure(image=label_image_content)
label_image.update()
def save():
os.makedirs('generated', exist_ok=True)  # the output directory may not exist yet
final_image.save('generated/out.png')
frm = ttk.Frame(root, padding=16)
frm.grid()
props = ttk.Frame(frm)
# outer structure (hbox)
label_image = ttk.Label(frm, image=blank_image)
label_image.grid(column=0, row=0)
ttk.Label(frm, image=padding_image).grid(column=1, row=0)
props.grid(column=2, row=0)
# inner structure (property fields and buttons)
# prompt field
ttk.Label(props, text="Prompt:").grid(column=0, row=0)
ttk.Entry(props, textvariable=sv_prompt).grid(column=1, row=0)
#
ttk.Label(props, image=padding_image).grid(column=0, row=1)
# temperature field
ttk.Label(props, text="Temperature:").grid(column=0, row=2)
ttk.Entry(props, textvariable=sv_temperature).grid(column=1, row=2)
#
ttk.Label(props, image=padding_image).grid(column=0, row=3)
# topk field
ttk.Label(props, text="Top-K:").grid(column=0, row=4)
ttk.Entry(props, textvariable=sv_topk).grid(column=1, row=4)
#
ttk.Label(props, image=padding_image).grid(column=0, row=5)
# superconditioning field
ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6)
ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6)
#
ttk.Label(props, image=padding_image).grid(column=0, row=7)
# seamless
ttk.Label(props, text="Seamless:").grid(column=0, row=8)
ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8)
#
ttk.Label(props, image=padding_image).grid(column=0, row=9)
# buttons
ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10)
ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=10)
ttk.Button(props, text="Save", command=save).grid(column=2, row=10)
root.mainloop()
import torch
import argparse
from min_dalle import MinDalle
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Generate images from text using min-DALL·E.')
parser.add_argument('--text', type=str, required=True, help='The text prompt for generating images')
parser.add_argument('--seed', type=int, default=-1, help='Random seed for reproducibility')
parser.add_argument('--grid_size', type=int, default=4, help='Grid size for the generated image')
parser.add_argument('--is_seamless', action='store_true', help='Generate seamless images')
parser.add_argument('--temperature', type=float, default=1, help='Sampling temperature')
parser.add_argument('--top_k', type=int, default=256, help='Top-k sampling')
parser.add_argument('--supercondition_factor', type=int, default=32, help='Super conditioning factor')
args = parser.parse_args()
# Configure the model
model = MinDalle(
models_root='./pretrained',
dtype=torch.float32, # switch to torch.float16 to save GPU memory
device='cuda', # or 'cpu'
is_mega=True,
is_reusable=True
)
# Generate the image
image = model.generate_image(
text=args.text,
seed=args.seed,
grid_size=args.grid_size,
is_seamless=args.is_seamless,
temperature=args.temperature,
top_k=args.top_k,
supercondition_factor=args.supercondition_factor,
is_verbose=False
)
# Display the image
image.show()
# Sanitize the filename: replace characters that are not alphanumeric
filename = "".join([c if c.isalnum() else "_" for c in args.text])
filename = filename[:50] # limit the filename length
# Save the image
image.save(f'{filename}.png')
if __name__ == '__main__':
main()
# python use.py --text "your prompt here" --seed 42 --grid_size 4 --temperature 0.7 --top_k 128 --supercondition_factor 16
# --text: the text prompt used to generate the image (required).
# --seed: random seed; vary it to get different images (default -1).
# --grid_size: size of the generated image grid (default 4).
# --is_seamless: if set, generate a seamless (tileable) image.
# --temperature: controls the diversity of the generated images (default 1).
# --top_k: number of highest-probability tokens considered when sampling (default 256).
# --supercondition_factor: controls how strongly the image follows the text prompt (default 32).
from .min_dalle import MinDalle
++ "b/02-\347\254\254\344\270\211\346\234\237/\350\241\245\345\205\205\350\265\204\346\226\231/.gitkeep"
Download link for the Phase 3 supplementary materials<br>
Link: https://pan.baidu.com/s/1tUbftYm9TQIu8sjr63cN6Q?pwd=DTAI<br>
Extraction code: DTAI<br>