You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
268 lines
6.7 KiB
268 lines
6.7 KiB
import sys, re, os, traceback |
|
from sets import Set |
|
|
|
def die(*args): |
|
printList(args, sys.stderr) |
|
sys.exit(2) |
|
|
|
def printList(list, file=sys.stdout): |
|
for x in list: |
|
file.write(str(x)) |
|
file.write(' ') |
|
file.write('\n') |
|
|
|
import subprocess |
|
|
|
# Debugging machinery |
|
# ------------------- |
|
|
|
DEBUG = 0 |
|
functionsToDebug = Set() |
|
|
|
def addDebug(func): |
|
if type(func) == str: |
|
functionsToDebug.add(func) |
|
else: |
|
functionsToDebug.add(func.func_name) |
|
|
|
def debug(*args): |
|
if DEBUG: |
|
funcName = traceback.extract_stack()[-2][2] |
|
if funcName in functionsToDebug: |
|
printList(args) |
|
|
|
# Program execution |
|
# ----------------- |
|
|
|
class ProgramError(Exception): |
|
def __init__(self, progStr, error): |
|
self.progStr = progStr |
|
self.error = error |
|
|
|
def __str__(self): |
|
return self.progStr + ': ' + self.error |
|
|
|
addDebug('runProgram') |
|
def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True): |
|
debug('runProgram prog:', str(prog), 'input:', str(input)) |
|
if type(prog) is str: |
|
progStr = prog |
|
else: |
|
progStr = ' '.join(prog) |
|
|
|
try: |
|
if pipeOutput: |
|
stderr = subprocess.STDOUT |
|
stdout = subprocess.PIPE |
|
else: |
|
stderr = None |
|
stdout = None |
|
pop = subprocess.Popen(prog, |
|
shell = type(prog) is str, |
|
stderr=stderr, |
|
stdout=stdout, |
|
stdin=subprocess.PIPE, |
|
env=env) |
|
except OSError, e: |
|
debug('strerror:', e.strerror) |
|
raise ProgramError(progStr, e.strerror) |
|
|
|
if input != None: |
|
pop.stdin.write(input) |
|
pop.stdin.close() |
|
|
|
if pipeOutput: |
|
out = pop.stdout.read() |
|
else: |
|
out = '' |
|
|
|
code = pop.wait() |
|
if returnCode: |
|
ret = [out, code] |
|
else: |
|
ret = out |
|
if code != 0 and not returnCode: |
|
debug('error output:', out) |
|
debug('prog:', prog) |
|
raise ProgramError(progStr, out) |
|
# debug('output:', out.replace('\0', '\n')) |
|
return ret |
|
|
|
# Code for computing common ancestors |
|
# ----------------------------------- |
|
|
|
currentId = 0 |
|
def getUniqueId(): |
|
global currentId |
|
currentId += 1 |
|
return currentId |
|
|
|
# The 'virtual' commit objects have SHAs which are integers |
|
shaRE = re.compile('^[0-9a-f]{40}$') |
|
def isSha(obj): |
|
return (type(obj) is str and bool(shaRE.match(obj))) or \ |
|
(type(obj) is int and obj >= 1) |
|
|
|
class Commit: |
|
def __init__(self, sha, parents, tree=None): |
|
self.parents = parents |
|
self.firstLineMsg = None |
|
self.children = [] |
|
|
|
if tree: |
|
tree = tree.rstrip() |
|
assert(isSha(tree)) |
|
self._tree = tree |
|
|
|
if not sha: |
|
self.sha = getUniqueId() |
|
self.virtual = True |
|
self.firstLineMsg = 'virtual commit' |
|
assert(isSha(tree)) |
|
else: |
|
self.virtual = False |
|
self.sha = sha.rstrip() |
|
assert(isSha(self.sha)) |
|
|
|
def tree(self): |
|
self.getInfo() |
|
assert(self._tree != None) |
|
return self._tree |
|
|
|
def shortInfo(self): |
|
self.getInfo() |
|
return str(self.sha) + ' ' + self.firstLineMsg |
|
|
|
def __str__(self): |
|
return self.shortInfo() |
|
|
|
def getInfo(self): |
|
if self.virtual or self.firstLineMsg != None: |
|
return |
|
else: |
|
info = runProgram(['git-cat-file', 'commit', self.sha]) |
|
info = info.split('\n') |
|
msg = False |
|
for l in info: |
|
if msg: |
|
self.firstLineMsg = l |
|
break |
|
else: |
|
if l.startswith('tree'): |
|
self._tree = l[5:].rstrip() |
|
elif l == '': |
|
msg = True |
|
|
|
class Graph: |
|
def __init__(self): |
|
self.commits = [] |
|
self.shaMap = {} |
|
|
|
def addNode(self, node): |
|
assert(isinstance(node, Commit)) |
|
self.shaMap[node.sha] = node |
|
self.commits.append(node) |
|
for p in node.parents: |
|
p.children.append(node) |
|
return node |
|
|
|
def reachableNodes(self, n1, n2): |
|
res = {} |
|
def traverse(n): |
|
res[n] = True |
|
for p in n.parents: |
|
traverse(p) |
|
|
|
traverse(n1) |
|
traverse(n2) |
|
return res |
|
|
|
def fixParents(self, node): |
|
for x in range(0, len(node.parents)): |
|
node.parents[x] = self.shaMap[node.parents[x]] |
|
|
|
# addDebug('buildGraph') |
|
def buildGraph(heads): |
|
debug('buildGraph heads:', heads) |
|
for h in heads: |
|
assert(isSha(h)) |
|
|
|
g = Graph() |
|
|
|
out = runProgram(['git-rev-list', '--parents'] + heads) |
|
for l in out.split('\n'): |
|
if l == '': |
|
continue |
|
shas = l.split(' ') |
|
|
|
# This is a hack, we temporarily use the 'parents' attribute |
|
# to contain a list of SHA1:s. They are later replaced by proper |
|
# Commit objects. |
|
c = Commit(shas[0], shas[1:]) |
|
|
|
g.commits.append(c) |
|
g.shaMap[c.sha] = c |
|
|
|
for c in g.commits: |
|
g.fixParents(c) |
|
|
|
for c in g.commits: |
|
for p in c.parents: |
|
p.children.append(c) |
|
return g |
|
|
|
# Write the empty tree to the object database and return its SHA1 |
|
def writeEmptyTree(): |
|
tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index' |
|
def delTmpIndex(): |
|
try: |
|
os.unlink(tmpIndex) |
|
except OSError: |
|
pass |
|
delTmpIndex() |
|
newEnv = os.environ.copy() |
|
newEnv['GIT_INDEX_FILE'] = tmpIndex |
|
res = runProgram(['git-write-tree'], env=newEnv).rstrip() |
|
delTmpIndex() |
|
return res |
|
|
|
def addCommonRoot(graph): |
|
roots = [] |
|
for c in graph.commits: |
|
if len(c.parents) == 0: |
|
roots.append(c) |
|
|
|
superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree()) |
|
graph.addNode(superRoot) |
|
for r in roots: |
|
r.parents = [superRoot] |
|
superRoot.children = roots |
|
return superRoot |
|
|
|
def getCommonAncestors(graph, commit1, commit2): |
|
'''Find the common ancestors for commit1 and commit2''' |
|
assert(isinstance(commit1, Commit) and isinstance(commit2, Commit)) |
|
|
|
def traverse(start, set): |
|
stack = [start] |
|
while len(stack) > 0: |
|
el = stack.pop() |
|
set.add(el) |
|
for p in el.parents: |
|
if p not in set: |
|
stack.append(p) |
|
h1Set = Set() |
|
h2Set = Set() |
|
traverse(commit1, h1Set) |
|
traverse(commit2, h2Set) |
|
shared = h1Set.intersection(h2Set) |
|
|
|
if len(shared) == 0: |
|
shared = [addCommonRoot(graph)] |
|
|
|
res = Set() |
|
|
|
for s in shared: |
|
if len([c for c in s.children if c in shared]) == 0: |
|
res.add(s) |
|
return list(res)
|
|
|